Skip to content

Commit 5540ac3

Browse files
committed
Fix code and add Unit Tests for CertificateCache
1 parent 6b3455b commit 5540ac3

File tree

4 files changed

+110
-27
lines changed

4 files changed

+110
-27
lines changed

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<Copyright>© Microsoft Corporation. All rights reserved.</Copyright>
88
<PackageIcon></PackageIcon>
99
<PackageIconFullPath></PackageIconFullPath>
10-
<LangVersion>12.0</LangVersion>
10+
<LangVersion>13.0</LangVersion>
1111
<PackageLicenseExpression>MIT</PackageLicenseExpression>
1212
<StrongNameKeyId>Microsoft</StrongNameKeyId>
1313
<EmbedUntrackedSources>true</EmbedUntrackedSources>

src/Kubernetes.Controller/Certificates/ImmutableCertificateCache.cs

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,9 @@ public ImmutableCertificateCache(IEnumerable<TCert> certificates, Func<TCert, IE
3434
_wildCardDomains.Sort(DomainNameComparer.Instance);
3535
}
3636

37-
public bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) =>
38-
_certificates.TryGetValue(domain, out certificate);
39-
40-
public bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate)
41-
{
42-
if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index and < -1)
43-
{
44-
var candidate = _wildCardDomains[~index];
45-
if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture))
46-
{
47-
certificate = candidate.Certificate;
48-
return true;
49-
}
50-
}
5137

52-
certificate = null;
53-
return false;
54-
}
5538

56-
public TCert? GetDefaultCertificate() => _wildCardDomains.FirstOrDefault()?.Certificate
57-
?? _certificates.Values.FirstOrDefault();
39+
protected abstract TCert? GetDefaultCertificate();
5840

5941
public TCert? GetCertificate(string domain)
6042
{
@@ -70,7 +52,35 @@ public bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCe
7052
return GetDefaultCertificate();
7153
}
7254

73-
private record WildCardDomain(string Domain, TCert Certificate);
55+
protected IReadOnlyList<WildCardDomain> WildcardCertificates => _wildCardDomains;
56+
57+
protected IReadOnlyDictionary<string, TCert> Certificates => _certificates;
58+
59+
protected record WildCardDomain(string Domain, TCert? Certificate);
60+
61+
private bool TryGetCertificateExact(string domain, [NotNullWhen(true)] out TCert? certificate) =>
62+
_certificates.TryGetValue(domain, out certificate);
63+
64+
private bool TryGetWildcardCertificate(string domain, [NotNullWhen(true)] out TCert? certificate)
65+
{
66+
if (_wildCardDomains.BinarySearch(new WildCardDomain(domain, null!), DomainNameComparer.Instance) is { } index)
67+
{
68+
if (index > -1)
69+
{
70+
certificate = _wildCardDomains[index].Certificate!;
71+
return true;
72+
}
73+
// var candidate = _wildCardDomains[~index];
74+
// if (domain.EndsWith(candidate.Domain, true, CultureInfo.InvariantCulture))
75+
// {
76+
// certificate = candidate.Certificate!;
77+
// return true;
78+
// }
79+
}
80+
81+
certificate = null;
82+
return false;
83+
}
7484

7585
/// <summary>
7686
/// Sorts domain names right to left.
@@ -83,7 +93,21 @@ private class DomainNameComparer : IComparer<WildCardDomain>
8393

8494
public int Compare(WildCardDomain? x, WildCardDomain? y)
8595
{
86-
return Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan());
96+
var ret = Compare(x!.Domain.AsSpan(), y!.Domain.AsSpan());
97+
if (ret != 0)
98+
{
99+
return ret;
100+
}
101+
102+
switch (x!.Certificate, y!.Certificate)
103+
{
104+
case (null, {}) when x.Domain.Length > y.Domain.Length:
105+
return 0;
106+
case ({}, null) when x.Domain.Length < y.Domain.Length:
107+
return 0;
108+
default:
109+
return x.Domain.Length - y.Domain.Length;
110+
}
87111
}
88112

89113
private static int Compare(ReadOnlySpan<char> x, ReadOnlySpan<char> y)
@@ -93,8 +117,8 @@ private static int Compare(ReadOnlySpan<char> x, ReadOnlySpan<char> y)
93117

94118
for (var i = 1; i <= length; i++)
95119
{
96-
var charA = x[^i] & 0x3F;
97-
var charB = y[^i] & 0x3F;
120+
var charA = x[^i] & 0x5F;
121+
var charB = y[^i] & 0x5F;
98122

99123
if (charA == charB)
100124
{
@@ -104,7 +128,8 @@ private static int Compare(ReadOnlySpan<char> x, ReadOnlySpan<char> y)
104128
return charB - charA;
105129
}
106130

107-
return x.Length - y.Length;
131+
//return x.Length - y.Length;
132+
return 0;
108133
}
109134
}
110135
}

src/Kubernetes.Controller/Certificates/ServerCertificateSelector.CertificateCache.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,25 @@
33
#nullable enable
44
using System.Collections.Generic;
55
using System.Formats.Asn1;
6+
using System.Linq;
67
using System.Security.Cryptography.X509Certificates;
78

89
namespace Yarp.Kubernetes.Controller.Certificates;
910

1011
internal partial class ServerCertificateSelector
1112
{
1213
private class ImmutableX509CertificateCache(IEnumerable<X509Certificate2> certificates)
13-
: ImmutableCertificateCache<X509Certificate2>(certificates, GetDomains);
14+
: ImmutableCertificateCache<X509Certificate2>(certificates, GetDomains)
15+
{
16+
protected override X509Certificate2? GetDefaultCertificate()
17+
{
18+
if (WildcardCertificates.Count != 0)
19+
{
20+
return WildcardCertificates[0].Certificate;
21+
}
22+
return Certificates.Values.FirstOrDefault();
23+
}
24+
}
1425

1526
private static IEnumerable<string> GetDomains(X509Certificate2 certificate)
1627
{
Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,56 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Collections.Generic;
5+
using Xunit;
6+
#nullable enable
47
namespace Yarp.Kubernetes.Controller.Certificates.Tests;
58

69
public class CertificateCacheTests
710
{
8-
11+
12+
private static readonly FakeCertificateCache Cache = new(
13+
new FakeCertificate("Acme", "mail.acme.com", "www.acme.com"),
14+
new FakeCertificate("Initech", "*.initech.com", "initech.com"),
15+
new FakeCertificate("Northwind", "*.northwind.com")
16+
);
17+
18+
[Theory]
19+
[InlineData("www.acme.com", "Acme")]
20+
[InlineData("www.ACME.com", "Acme")]
21+
[InlineData("mail.acme.com", "Acme")]
22+
[InlineData("acme.com", null)]
23+
[InlineData("store.acme.com", null)]
24+
[InlineData("www.northwind.com", "Northwind")]
25+
[InlineData("mail.northwind.com", "Northwind")]
26+
[InlineData("northwind.com", null)]
27+
[InlineData("initech.com", "Initech")]
28+
[InlineData("www.initech.com", "Initech")]
29+
[InlineData("www.IniTech.coM", "Initech")]
30+
public void CertificateConversionFromPem(string requestedDomain, string? expectedCompanyName)
31+
{
32+
var certificate = Cache.GetCertificate(requestedDomain);
33+
if (expectedCompanyName != null)
34+
{
35+
Assert.Equal(expectedCompanyName, certificate?.Name);
36+
}
37+
else
38+
{
39+
Assert.Null(certificate?.Name);
40+
}
41+
}
42+
43+
private record FakeCertificate(string Name, params string[] Domains);
44+
45+
private class FakeCertificateCache(params IEnumerable<FakeCertificate> certificates)
46+
: ImmutableCertificateCache<FakeCertificate>(certificates, static cert => cert.Domains)
47+
{
48+
protected override FakeCertificate? GetDefaultCertificate()
49+
{
50+
return null;
51+
}
52+
}
953
}
54+
55+
56+

0 commit comments

Comments
 (0)