Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/DnsClient/DnsMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal enum DnsMessageHandleType
TCP
}

internal abstract class DnsMessageHandler
internal abstract class DnsMessageHandler : IDisposable
{
public abstract DnsMessageHandleType Type { get; }

Expand Down Expand Up @@ -170,5 +170,9 @@ public virtual DnsResponseMessage GetResponseMessage(ArraySegment<byte> response

return response;
}

public virtual void Dispose()
{
}
}
}
137 changes: 119 additions & 18 deletions src/DnsClient/DnsTcpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ namespace DnsClient
{
internal class DnsTcpMessageHandler : DnsMessageHandler
{
private bool _disposedValue = false;
private readonly ConcurrentDictionary<IPEndPoint, ClientPool> _pools = new ConcurrentDictionary<IPEndPoint, ClientPool>();

public override DnsMessageHandleType Type { get; } = DnsMessageHandleType.TCP;

public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage request, TimeSpan timeout)
{
CancellationToken cancellationToken = default;
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

using var cts = timeout.TotalMilliseconds != Timeout.Infinite && timeout.TotalMilliseconds < int.MaxValue ?
new CancellationTokenSource(timeout) : null;

cancellationToken = cts?.Token ?? default;
var cancellationToken = cts?.Token ?? default;

ClientPool pool;
while (!_pools.TryGetValue(server, out pool))
Expand All @@ -32,7 +36,7 @@ public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage re

cancellationToken.ThrowIfCancellationRequested();

var entry = pool.GetNextClient();
var entry = pool.GetNextClient(cancellationToken);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -69,6 +73,11 @@ public override async Task<DnsResponseMessage> QueryAsync(
DnsRequestMessage request,
CancellationToken cancellationToken)
{
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

cancellationToken.ThrowIfCancellationRequested();

ClientPool pool;
Expand All @@ -77,7 +86,7 @@ public override async Task<DnsResponseMessage> QueryAsync(
_pools.TryAdd(server, new ClientPool(true, server));
}

var entry = await pool.GetNextClientAsync().ConfigureAwait(false);
var entry = await pool.GetNextClientAsync(cancellationToken).ConfigureAwait(false);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -281,6 +290,30 @@ private async Task<DnsResponseMessage> QueryAsyncInternal(TcpClient client, DnsR
return DnsResponseMessage.Combine(responses);
}


protected virtual void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
foreach (var entry in _pools)
{
entry.Value.Dispose();
}

_pools.Clear();
}

_disposedValue = true;
}
}

public override void Dispose()
{
Dispose(true);
}

private class ClientPool : IDisposable
{
private bool _disposedValue = false;
Expand All @@ -294,7 +327,7 @@ public ClientPool(bool enablePool, IPEndPoint endpoint)
_endpoint = endpoint;
}

public ClientEntry GetNextClient()
public ClientEntry GetNextClient(CancellationToken cancellationToken)
{
if (_disposedValue)
{
Expand All @@ -306,20 +339,54 @@ public ClientEntry GetNextClient()
{
while (entry == null && !TryDequeue(out entry))
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily) { LingerState = new LingerOption(true, 0) }, _endpoint);
entry.Client.Connect(_endpoint.Address, _endpoint.Port);
entry = ConnectNew();
}
}
else
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily), _endpoint);
entry.Client.Connect(_endpoint.Address, _endpoint.Port);
entry = ConnectNew();
}

return entry;

ClientEntry ConnectNew()
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
newClient.Connect(_endpoint.Address, _endpoint.Port);
}
catch (Exception) when (gotCanceled)
{
throw new TimeoutException("Connection timed out.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caller Query method doesn't seem to translate cancellations to TimeoutException.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly sure what you mean by that.
What I want to do here is allowing LookupClient does handle OperationCanceled- and Timeout-Exceptions and retry the request if possible.
For TCP, if the connection attempt takes longer then the configured timeout for a query, I want to handle it the same way and maybe have LookupClient retry.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed both, sync and async version to throw new OperationCanceledException("Connection timed out.", cancellationToken);
Still, its mostly a signal for LookupClient to retry

}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}

return new ClientEntry(newClient, _endpoint);
}
}

public async Task<ClientEntry> GetNextClientAsync()
public async Task<ClientEntry> GetNextClientAsync(CancellationToken cancellationToken)
{
if (_disposedValue)
{
Expand All @@ -331,17 +398,55 @@ public async Task<ClientEntry> GetNextClientAsync()
{
while (entry == null && !TryDequeue(out entry))
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily) { LingerState = new LingerOption(true, 0) }, _endpoint);
await entry.Client.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
entry = await ConnectNew().ConfigureAwait(false);
}
}
else
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily), _endpoint);
await entry.Client.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
entry = await ConnectNew().ConfigureAwait(false);
}

return entry;

async Task<ClientEntry> ConnectNew()
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

#if NET6_0_OR_GREATER
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port, cancellationToken).ConfigureAwait(false);
#else

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
}
catch (Exception) when (gotCanceled)
{
throw new TimeoutException("Connection timed out.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this method is async, it shouldn't be translated to TimeoutException.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll change both, the sync and async version to throw new OperationCanceledException("Connection timed out.", cancellationToken);

}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}
#endif
return new ClientEntry(newClient, _endpoint);
}
}

public void Enqueue(ClientEntry entry)
Expand Down Expand Up @@ -432,11 +537,7 @@ public void DisposeClient()
{
try
{
#if !NET45
Client.Dispose();
#else
Client.Close();
#endif
}
catch { }
}
Expand Down
1 change: 0 additions & 1 deletion src/DnsClient/DnsUdpMessageHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Threading;
Expand Down
2 changes: 1 addition & 1 deletion src/DnsClient/ILookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ public interface ILookupClient : IDnsQuery

#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
}
}
}
30 changes: 27 additions & 3 deletions src/DnsClient/LookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace DnsClient
/// ]]>
/// </code>
/// </example>
public class LookupClient : ILookupClient, IDnsQuery
public sealed class LookupClient : ILookupClient, IDnsQuery, IDisposable
{
private const int LogEventStartQuery = 1;
private const int LogEventQuery = 2;
Expand All @@ -62,6 +62,7 @@ public class LookupClient : ILookupClient, IDnsQuery
private readonly SkipWorker _skipper = null;

private IReadOnlyCollection<NameServer> _resolvedNameServers;
private bool _disposedValue;

/// <inheritdoc/>
public IReadOnlyCollection<NameServer> NameServers => Settings.NameServers;
Expand Down Expand Up @@ -370,7 +371,7 @@ internal LookupClient(LookupClientOptions options, DnsMessageHandler udpHandler

// Setting up name servers.
// Using manually configured ones and/or auto resolved ones.
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? new NameServer[0];
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? Array.Empty<NameServer>();

if (options.AutoResolveNameServers)
{
Expand Down Expand Up @@ -427,7 +428,9 @@ private void CheckResolvedNameservers()
}

_resolvedNameServers = newServers;
var servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
servers = NameServer.ValidateNameServers(servers, _logger);

Settings = new LookupClientSettings(_originalOptions, servers);
}
catch (Exception ex)
Expand Down Expand Up @@ -1787,6 +1790,27 @@ public void MaybeDoWork()
}
}
}

private void Dispose(bool disposing)
{
if (!_disposedValue)
{
if (disposing)
{
_tcpFallbackHandler?.Dispose();
_messageHandler?.Dispose();
}

_disposedValue = true;
}
}

/// <inheritdoc/>
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

internal class LookupClientAudit
Expand Down
Loading