Skip to content

Commit

Permalink
Consider a trailing dot when resolve DNS with search domains (#5963)
Browse files Browse the repository at this point in the history
Motivation:

There was a report from LY internally where DNS resolver warned for
`NXDomain` unexpectedly.
```java
java.util.concurrent.CompletionException: java.lang.IllegalArgumentException: Empty label is not a legal name
	at java.base/java.util.concurrent.CompletableFuture.encodeThrowable(CompletableFuture.java:315)
        ...
	at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver.resolve0(SearchDomainDnsResolver.java:99)
	at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver.resolve(SearchDomainDnsResolver.java:88)
	at com.linecorp.armeria.internal.client.dns.HostsFileDnsResolver.resolve(HostsFileDnsResolver.java:130)
	at com.linecorp.armeria.internal.client.dns.DefaultDnsResolver.resolveOne(DefaultDnsResolver.java:89)
	at com.linecorp.armeria.internal.client.dns.DefaultDnsResolver.resolve(DefaultDnsResolver.java:81)
	at com.linecorp.armeria.client.endpoint.dns.DnsEndpointGroup.sendQueries(DnsEndpointGroup.java:155)
	at com.linecorp.armeria.client.endpoint.dns.DnsEndpointGroup.lambda$sendQueries$3(DnsEndpointGroup.java:173)
       ...
Caused by: java.lang.IllegalArgumentException: Empty label is not a legal name
  at java.base/java.net.IDN.toASCIIInternal(IDN.java:284)
  at java.base/java.net.IDN.toASCII(IDN.java:123)
  at java.base/java.net.IDN.toASCII(IDN.java:152)
  at com.linecorp.armeria.internal.client.dns.DnsQuestionWithoutTrailingDot.<init>(DnsQuestionWithoutTrailingDot.java:53)
  at com.linecorp.armeria.internal.client.dns.DnsQuestionWithoutTrailingDot.of(DnsQuestionWithoutTrailingDot.java:48)
  at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver$SearchDomainQuestionContext.newQuestion(SearchDomainDnsResolver.java:190)
  at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver$SearchDomainQuestionContext.nextQuestion0(SearchDomainDnsResolver.java:177)
  at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver$SearchDomainQuestionContext.nextQuestion(SearchDomainDnsResolver.java:150)
  at com.linecorp.armeria.internal.client.dns.SearchDomainDnsResolver.lambda$resolve0$1(SearchDomainDnsResolver.java:103)
  ... 19 common frames omitted
```

The NX domain has a trailing dot and search domains start with a dot
(`.`).
As a result, `example.com..search.domain` was made and rejected by
`java.net.IDN`

Modifications:

- Remove a leading dot from the normalized search domains.
- Infix a dot when a hostname does not have a trailing dot.

Result:

DNS resolver now correctly adds search domains for hostnames with
trailing dots.
  • Loading branch information
ikhoon authored Nov 8, 2024
1 parent 1ae1e9f commit 460ea02
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private CompletableFuture<List<DnsRecord>> resolveOne(DnsQuestionContext ctx, Dn
});
future.handle((unused0, unused1) -> {
// Maybe cancel the timeout scheduler.
ctx.cancel();
ctx.setComplete();
return null;
});
return future;
Expand All @@ -112,7 +112,7 @@ CompletableFuture<List<DnsRecord>> resolveAll(DnsQuestionContext ctx, List<? ext
final int order = i;
delegate.resolve(ctx, questions.get(i)).handle((records, cause) -> {
assert executor.inEventLoop();
maybeCompletePreferredRecords(future, questions, results, order, records, cause);
maybeCompletePreferredRecords(ctx, future, questions, results, order, records, cause);
return null;
});
}
Expand Down Expand Up @@ -140,7 +140,8 @@ CompletableFuture<List<DnsRecord>> resolveAll(DnsQuestionContext ctx, List<? ext
}

@VisibleForTesting
static void maybeCompletePreferredRecords(CompletableFuture<List<DnsRecord>> future,
static void maybeCompletePreferredRecords(DnsQuestionContext ctx,
CompletableFuture<List<DnsRecord>> future,
List<? extends DnsQuestion> questions,
Object[] results, int order,
@Nullable List<DnsRecord> records,
Expand Down Expand Up @@ -170,6 +171,7 @@ static void maybeCompletePreferredRecords(CompletableFuture<List<DnsRecord>> fut
// Found a successful result.
assert result instanceof List;
future.complete(Collections.unmodifiableList((List<DnsRecord>) result));
ctx.setComplete();
return;
}

Expand All @@ -181,6 +183,7 @@ static void maybeCompletePreferredRecords(CompletableFuture<List<DnsRecord>> fut
unknownHostException.addSuppressed((Throwable) result);
}
future.completeExceptionally(unknownHostException);
ctx.setComplete();
}

public DnsCache dnsCache() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ final class DnsQuestionContext {
private final long queryTimeoutMillis;
private final CompletableFuture<Void> whenCancelled = new CompletableFuture<>();
private final ScheduledFuture<?> scheduledFuture;
private boolean complete;

DnsQuestionContext(EventExecutor executor, long queryTimeoutMillis) {
this.queryTimeoutMillis = queryTimeoutMillis;
Expand All @@ -48,12 +49,21 @@ boolean isCancelled() {
return whenCancelled.isCompletedExceptionally();
}

void cancel() {
void cancelScheduler() {
if (!scheduledFuture.isDone()) {
scheduledFuture.cancel(false);
}
}

void setComplete() {
complete = true;
cancelScheduler();
}

boolean isCompleted() {
return complete;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -65,6 +75,7 @@ public boolean equals(Object o) {

final DnsQuestionContext that = (DnsQuestionContext) o;
return queryTimeoutMillis == that.queryTimeoutMillis &&
complete == that.complete &&
whenCancelled.equals(that.whenCancelled) &&
scheduledFuture.equals(that.scheduledFuture);
}
Expand All @@ -74,6 +85,7 @@ public int hashCode() {
int result = whenCancelled.hashCode();
result = 31 * result + scheduledFuture.hashCode();
result = 31 * result + (int) queryTimeoutMillis;
result = 31 * result + (complete ? 1 : 0);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.linecorp.armeria.internal.client.dns;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.collect.ImmutableList.toImmutableList;

import java.util.List;
Expand All @@ -28,6 +29,7 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.AbstractUnwrappable;
Expand Down Expand Up @@ -60,15 +62,18 @@ private static List<String> validateSearchDomain(List<String> searchDomains) {
return null;
}
String normalized = searchDomain;
if (searchDomain.charAt(0) != '.') {
normalized = '.' + searchDomain;
if (searchDomain.charAt(0) == '.') {
// Remove the leading dot.
normalized = searchDomain.substring(1);
}
if (searchDomain.charAt(searchDomain.length() - 1) != '.') {
if (normalized.charAt(normalized.length() - 1) != '.') {
// Add a trailing dot.
normalized += '.';
}
try {
// Try to create a sample DnsQuestion to validate the search domain.
DnsQuestionWithoutTrailingDot.of("localhost" + normalized, DnsRecordType.A);
DnsQuestionWithoutTrailingDot.of("localhost." + normalized,
DnsRecordType.A);
return normalized;
} catch (Exception ex) {
logger.warn("Ignoring a malformed search domain: '{}'", searchDomain, ex);
Expand Down Expand Up @@ -96,6 +101,11 @@ private CompletableFuture<List<DnsRecord>> resolve0(DnsQuestionContext ctx,
new IllegalStateException("resolver is closed already"));
}

if (ctx.isCompleted()) {
// Other DnsRecordType may be resolved already.
return UnmodifiableFuture.completedFuture(ImmutableList.of());
}

return unwrap().resolve(ctx, question).handle((records, cause) -> {
if (records != null) {
return UnmodifiableFuture.completedFuture(records);
Expand Down Expand Up @@ -126,14 +136,18 @@ static final class SearchDomainQuestionContext {
private final DnsQuestion original;
private final String originalName;
private final List<String> searchDomains;
private final int numSearchDomains;
private final boolean shouldStartWithHostname;
private final boolean hasTrailingDot;
private volatile int numAttemptsSoFar;

SearchDomainQuestionContext(DnsQuestion original, List<String> searchDomains, int ndots) {
this.original = original;
this.searchDomains = searchDomains;
numSearchDomains = searchDomains.size();
originalName = original.name();
shouldStartWithHostname = hasNDots(originalName, ndots);
hasTrailingDot = originalName.endsWith(".");
shouldStartWithHostname = hasNDots(originalName, ndots) || hasTrailingDot || numSearchDomains == 0;
}

private static boolean hasNDots(String hostname, int ndots) {
Expand All @@ -157,32 +171,46 @@ DnsQuestion nextQuestion() {
@Nullable
private DnsQuestion nextQuestion0() {
final int numAttemptsSoFar = this.numAttemptsSoFar;
if (numAttemptsSoFar == 0) {
if (originalName.endsWith(".") || searchDomains.isEmpty()) {
return original;
}
if (shouldStartWithHostname) {
return newQuestion(originalName + '.');

final int searchDomainPos;
if (shouldStartWithHostname) {
searchDomainPos = numAttemptsSoFar - 1;
} else {
if (numAttemptsSoFar == numSearchDomains) {
// The last attempt uses the hostname itself.
searchDomainPos = -1;
} else {
return newQuestion(originalName + searchDomains.get(0));
searchDomainPos = numAttemptsSoFar;
}
}

int nextSearchDomainPos = numAttemptsSoFar;
if (shouldStartWithHostname) {
nextSearchDomainPos = numAttemptsSoFar - 1;
if (searchDomainPos >= numSearchDomains) {
// No more search domain to try.
return null;
}

if (nextSearchDomainPos < searchDomains.size()) {
return newQuestion(originalName + searchDomains.get(nextSearchDomainPos));
}
if (nextSearchDomainPos == searchDomains.size() && !shouldStartWithHostname) {
return newQuestion(originalName + '.');
final String searchDomain;
// -1 means the hostname itself.
if (searchDomainPos == -1) {
searchDomain = null;
} else {
searchDomain = searchDomains.get(searchDomainPos);
}
return null;

return newQuestion(searchDomain);
}

private DnsQuestion newQuestion(String hostname) {
private DnsQuestion newQuestion(@Nullable String searchDomain) {
searchDomain = firstNonNull(searchDomain, "");
final String hostname;
if (hasTrailingDot) {
if (searchDomain.isEmpty()) {
return original;
}
hostname = originalName + searchDomain;
} else {
hostname = originalName + '.' + searchDomain;
}
// - As the search domain is validated already, DnsQuestionWithoutTrailingDot should not raise an
// exception.
// - Use originalName to delete the cache value in RefreshingAddressResolver when the DnsQuestion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableMap;

Expand All @@ -43,10 +45,13 @@
import io.netty.handler.codec.dns.DefaultDnsResponse;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.util.ReferenceCountUtil;

class TrailingDotAddressResolverTest {

private static final Logger logger = LoggerFactory.getLogger(TrailingDotAddressResolverTest.class);

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
Expand Down Expand Up @@ -77,13 +82,15 @@ void resolve() throws Exception {
new DefaultDnsQuestion("foo.com.", A),
new DefaultDnsResponse(0).addRecord(ANSWER, newAddressRecord("foo.com.", "127.0.0.1"))),
dnsRecordCaptor)) {
try (ClientFactory factory = ClientFactory.builder()
.domainNameResolverCustomizer(b -> {
b.serverAddresses(dnsServer.addr());
b.searchDomains("search.domain1", "search.domain2");
b.ndots(3);
})
.build()) {
try (ClientFactory factory =
ClientFactory.builder()
.domainNameResolverCustomizer(b -> {
b.resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY);
b.serverAddresses(dnsServer.addr());
b.searchDomains("search.domain1", "search.domain2");
b.ndots(3);
})
.build()) {

final BlockingWebClient client = WebClient.builder()
.factory(factory)
Expand All @@ -93,6 +100,7 @@ void resolve() throws Exception {
"http://foo.com.:" + server.httpPort() + '/');
assertThat(response.contentUtf8()).isEqualTo("Hello, world!");
assertThat(dnsRecordCaptor.records).isNotEmpty();
logger.debug("Captured DNS records: {}", dnsRecordCaptor.records);
dnsRecordCaptor.records.forEach(record -> {
assertThat(record.name()).isEqualTo("foo.com.");
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,16 @@ void shouldWaitForPreferredRecords() {
DnsQuestionWithoutTrailingDot.of("foo.com.", DnsRecordType.AAAA));
final Object[] results = new Object[questions.size()];

final DnsQuestionContext ctx = new DnsQuestionContext(CommonPools.workerGroup().next(), Long.MAX_VALUE);
final List<DnsRecord> fooDnsRecord = ImmutableList.of(newAddressRecord("foo.com.", "1.2.3.4"));
final List<DnsRecord> barDnsRecord = ImmutableList.of(newAddressRecord("foo.com.", "2001:db8::1"));
// Should not complete `future` and wait for the first result.
maybeCompletePreferredRecords(future, questions, results, 1, barDnsRecord, null);
maybeCompletePreferredRecords(ctx, future, questions, results, 1, barDnsRecord, null);
assertThat(future).isNotCompleted();
maybeCompletePreferredRecords(future, questions, results, 0, fooDnsRecord, null);
assertThat(ctx.isCompleted()).isFalse();
maybeCompletePreferredRecords(ctx, future, questions, results, 0, fooDnsRecord, null);
assertThat(future).isCompletedWithValue(fooDnsRecord);
assertThat(ctx.isCompleted()).isTrue();
}

@Test
Expand All @@ -216,12 +219,15 @@ void shouldWaitForPreferredRecords_ignoreErrorsOnPrecedence() {
DnsQuestionWithoutTrailingDot.of("foo.com.", DnsRecordType.AAAA));
final Object[] results = new Object[questions.size()];

final DnsQuestionContext ctx = new DnsQuestionContext(CommonPools.workerGroup().next(), Long.MAX_VALUE);
final List<DnsRecord> barDnsRecord = ImmutableList.of(newAddressRecord("foo.com.", "2001:db8::1"));
// Should not complete `future` and wait for the first result.
maybeCompletePreferredRecords(future, questions, results, 1, barDnsRecord, null);
maybeCompletePreferredRecords(ctx, future, questions, results, 1, barDnsRecord, null);
assertThat(future).isNotCompleted();
maybeCompletePreferredRecords(future, questions, results, 0, null, new AnticipatedException());
assertThat(ctx.isCompleted()).isFalse();
maybeCompletePreferredRecords(ctx, future, questions, results, 0, null, new AnticipatedException());
assertThat(future).isCompletedWithValue(barDnsRecord);
assertThat(ctx.isCompleted()).isTrue();
}

@Test
Expand All @@ -232,10 +238,12 @@ void resolvePreferredRecordsFirst() {
DnsQuestionWithoutTrailingDot.of("foo.com.", DnsRecordType.AAAA));
final Object[] results = new Object[questions.size()];

final DnsQuestionContext ctx = new DnsQuestionContext(CommonPools.workerGroup().next(), Long.MAX_VALUE);
final List<DnsRecord> fooDnsRecord = ImmutableList.of(newAddressRecord("foo.com.", "1.2.3.4"));
maybeCompletePreferredRecords(future, questions, results, 0, fooDnsRecord, null);
maybeCompletePreferredRecords(ctx, future, questions, results, 0, fooDnsRecord, null);
// The preferred question is resolved. Don't need to wait for the questions.
assertThat(future).isCompletedWithValue(fooDnsRecord);
assertThat(ctx.isCompleted()).isTrue();
}

@Test
Expand All @@ -249,11 +257,14 @@ void shouldWaitForPreferredRecords_allQuestionsAreFailed() {
final List<DnsRecord> barDnsRecord = ImmutableList.of(newAddressRecord("foo.com.", "2001:db8::1"));
// Should not complete `future` and wait for the first result.
final AnticipatedException barCause = new AnticipatedException();
maybeCompletePreferredRecords(future, questions, results, 1, barDnsRecord, barCause);
final DnsQuestionContext ctx = new DnsQuestionContext(CommonPools.workerGroup().next(), Long.MAX_VALUE);
maybeCompletePreferredRecords(ctx, future, questions, results, 1, barDnsRecord, barCause);
assertThat(future).isNotCompleted();
assertThat(ctx.isCompleted()).isFalse();
final AnticipatedException fooCause = new AnticipatedException();
maybeCompletePreferredRecords(future, questions, results, 0, null, fooCause);
maybeCompletePreferredRecords(ctx, future, questions, results, 0, null, fooCause);
assertThat(future).isCompletedExceptionally();
assertThat(ctx.isCompleted()).isTrue();
assertThatThrownBy(future::join)
.isInstanceOf(CompletionException.class)
.cause()
Expand Down
Loading

0 comments on commit 460ea02

Please sign in to comment.