Skip to content

Commit

Permalink
Add non-blocking DNS resolver for Android API requests
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Nov 20, 2024
1 parent 3bc6b9f commit d8b6010
Show file tree
Hide file tree
Showing 16 changed files with 460 additions and 260 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
33 changes: 16 additions & 17 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
AddressCache,
AddressCache, DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::Uri;
use hyper_util::{
client::legacy::connect::dns::{GaiResolver, Name},
rt::TokioIo,
};
use hyper_util::rt::TokioIo;
use mullvad_encrypted_dns_proxy::{
config::ProxyConfig as EncryptedDNSConfig, Forwarder as EncryptedDNSForwarder,
};
Expand Down Expand Up @@ -291,6 +288,7 @@ pub struct HttpsConnectorWithSni {
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -307,6 +305,7 @@ impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
let (tx, mut rx) = mpsc::unbounded();
Expand Down Expand Up @@ -355,6 +354,7 @@ impl HttpsConnectorWithSni {
sni_hostname,
address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
},
Expand Down Expand Up @@ -388,7 +388,11 @@ impl HttpsConnectorWithSni {
.map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
}

async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> {
async fn resolve_address(
address_cache: AddressCache,
dns_resolver: &dyn DnsResolver,
uri: Uri,
) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;

let hostname = uri.host().ok_or_else(|| {
Expand All @@ -408,19 +412,13 @@ impl HttpsConnectorWithSni {
));
}

// Use getaddrinfo as a fallback
// Use DNS resolution as fallback
//
let mut addrs = GaiResolver::new()
.call(
Name::from_str(hostname)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
)
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.next()
.get(0)
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT)))
Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
}
}

Expand Down Expand Up @@ -455,6 +453,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();

let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
Expand All @@ -465,7 +464,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}

let hostname = sni_hostname?;
let addr = Self::resolve_address(address_cache, uri).await?;
let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
Expand Down
45 changes: 43 additions & 2 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use std::{
cell::Cell,
collections::BTreeMap,
future::Future,
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
ops::Deref,
path::Path,
sync::OnceLock,
pin::Pin,
sync::{Arc, OnceLock},
};
use talpid_types::ErrorExt;

Expand Down Expand Up @@ -304,11 +306,37 @@ impl ApiEndpoint {
}
}

pub trait DnsResolver: 'static + Send + Sync {
fn resolve(
&self,
host: String,
) -> Pin<Box<dyn Future<Output = io::Result<Vec<IpAddr>>> + Send>>;
}

pub struct DefaultDnsResolver;

impl DnsResolver for DefaultDnsResolver {
fn resolve(
&self,
host: String,
) -> Pin<Box<dyn Future<Output = io::Result<Vec<IpAddr>>> + Send>> {
use std::net::ToSocketAddrs;

Box::pin(async move {
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
Ok(addrs.map(|addr| addr.ip()).collect())
})
}
}

/// A type that helps with the creation of API connections.
pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -323,13 +351,20 @@ pub enum Error {

#[error("API availability check failed")]
ApiCheckError(#[from] availability::Error),

#[error("DNS resolution error")]
ResolutionFailed(#[from] std::io::Error),
}

impl Runtime {
/// Create a new `Runtime`.
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -346,12 +381,14 @@ impl Runtime {

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -360,11 +397,13 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
let handle = tokio::runtime::Handle::current();

#[cfg(feature = "api-override")]
if API.disable_address_cache {
return Self::new_inner(
Expand Down Expand Up @@ -402,6 +441,7 @@ impl Runtime {
handle,
address_cache,
api_availability,
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -419,6 +459,7 @@ impl Runtime {
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
Expand Down
3 changes: 3 additions & 0 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
DnsResolver,
};
use futures::{
channel::{mpsc, oneshot},
Expand Down Expand Up @@ -154,11 +155,13 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
address_cache.clone(),
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
);
Expand Down
1 change: 1 addition & 0 deletions mullvad-daemon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ tokio = { workspace = true, features = ["test-util"] }

[target.'cfg(target_os="android")'.dependencies]
android_logger = "0.8"
hickory-resolver = { version = "0.24.1" }

[target.'cfg(unix)'.dependencies]
nix = "0.23"
Expand Down
53 changes: 53 additions & 0 deletions mullvad-daemon/src/android_dns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#![cfg(target_os = "android")]
//! A non-blocking DNS resolver. `getaddrinfo` tends to prevent the tokio runtime from being
//! dropped, since it waits indefinitely on blocking threads. This is particularly bad on Android,
//! so we use a non-blocking resolver instead.
use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};
use mullvad_api::DnsResolver;
use std::{future::Future, io, net::IpAddr, pin::Pin};

pub struct AndroidDnsResolver {
connectivity_listener: talpid_core::connectivity_listener::ConnectivityListener,
}

impl AndroidDnsResolver {
pub fn new(
connectivity_listener: talpid_core::connectivity_listener::ConnectivityListener,
) -> Self {
Self {
connectivity_listener,
}
}
}

impl DnsResolver for AndroidDnsResolver {
fn resolve(
&self,
host: String,
) -> Pin<Box<dyn Future<Output = io::Result<Vec<IpAddr>>> + Send>> {
let ips = self.connectivity_listener.current_dns_servers();

Box::pin(async move {
let ips = ips.map_err(|err| {
io::Error::new(
io::ErrorKind::Other,
format!("Failed to retrieve current servers: {err}"),
)
})?;
let group = NameServerConfigGroup::from_ips_clear(&ips, 53, false);

let config = ResolverConfig::from_parts(None, vec![], group);
let resolver = TokioAsyncResolver::tokio(config, ResolverOpts::default());

let lookup = resolver.lookup_ip(host).await.map_err(|err| {
io::Error::new(io::ErrorKind::Other, format!("lookup_ip failed: {err}"))
})?;

Ok(lookup.into_iter().collect())
})
}
}
25 changes: 25 additions & 0 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

mod access_method;
pub mod account_history;
mod android_dns;
mod api;
mod api_address_updater;
#[cfg(not(target_os = "android"))]
Expand Down Expand Up @@ -38,6 +39,8 @@ use futures::{
};
use geoip::GeoIpHandler;
use management_interface::ManagementInterfaceServer;
#[cfg(not(target_os = "android"))]
use mullvad_api::DefaultDnsResolver;
use mullvad_relay_selector::{RelaySelector, SelectorConfig};
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
Expand Down Expand Up @@ -91,6 +94,9 @@ use talpid_types::{
};
use tokio::io;

#[cfg(target_os = "android")]
use talpid_core::connectivity_listener::ConnectivityListener;

/// Delay between generating a new WireGuard key and reconnecting
const WG_RECONNECT_DELAY: Duration = Duration::from_secs(4 * 60);

Expand Down Expand Up @@ -604,8 +610,25 @@ impl Daemon {

let (internal_event_tx, internal_event_rx) = daemon_command_channel.destructure();

#[cfg(target_os = "android")]
let connectivity_listener = match ConnectivityListener::new(android_context.clone()) {
Ok(listener) => listener,
Err(error) => {
log::warn!(
"{}",
error.display_chain_with_msg("Failed to start connectivity listener")
);
return Err(Error::DaemonUnavailable);
}
};

mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await;
let api_runtime = mullvad_api::Runtime::with_cache(
// FIXME: clone is bad (single sender)
#[cfg(target_os = "android")]
android_dns::AndroidDnsResolver::new(connectivity_listener.clone()),
#[cfg(not(target_os = "android"))]
DefaultDnsResolver,
&cache_dir,
true,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -777,6 +800,8 @@ impl Daemon {
volume_update_rx,
#[cfg(target_os = "android")]
android_context,
#[cfg(target_os = "android")]
connectivity_listener,
#[cfg(target_os = "linux")]
tunnel_state_machine::LinuxNetworkingIdentifiers {
fwmark: mullvad_types::TUNNEL_FWMARK,
Expand Down
Loading

0 comments on commit d8b6010

Please sign in to comment.