From d4cf37d017c9643f51a0c56daf17f24782fc907d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20G=C3=B6ransson?= Date: Wed, 20 Nov 2024 08:45:25 +0100 Subject: [PATCH] Refactor ConnectivityListener --- android/lib/talpid/build.gradle.kts | 1 + .../mullvad/talpid/ConnectivityListener.kt | 129 +++++++++-------- .../net/mullvad/talpid/TalpidVpnService.kt | 8 +- .../talpid/util/ConnectivityManagerUtil.kt | 132 ++++++++++++++++++ .../mullvadvpn/service/MullvadVpnService.kt | 2 +- 5 files changed, 209 insertions(+), 63 deletions(-) create mode 100644 android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt diff --git a/android/lib/talpid/build.gradle.kts b/android/lib/talpid/build.gradle.kts index a5cd613de189..c53c2add28dd 100644 --- a/android/lib/talpid/build.gradle.kts +++ b/android/lib/talpid/build.gradle.kts @@ -31,6 +31,7 @@ android { dependencies { implementation(projects.lib.model) + implementation(libs.androidx.ktx) implementation(libs.androidx.lifecycle.service) implementation(libs.kermit) implementation(libs.kotlin.stdlib) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt index f1fe3ca807b4..a37cf18578df 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/ConnectivityListener.kt @@ -1,86 +1,95 @@ package net.mullvad.talpid -import android.content.Context import android.net.ConnectivityManager -import android.net.ConnectivityManager.NetworkCallback import android.net.LinkProperties import android.net.Network import android.net.NetworkCapabilities import android.net.NetworkRequest -import co.touchlab.kermit.Logger import java.net.InetAddress -import kotlin.properties.Delegates.observable - -class ConnectivityListener { - private val availableNetworks = HashSet() - - private val callback = - object : NetworkCallback() { - override fun onAvailable(network: Network) { - availableNetworks.add(network) - isConnected = true - } - - override fun onLost(network: Network) { - availableNetworks.remove(network) - isConnected = availableNetworks.isNotEmpty() - } - } - - private val defaultNetworkCallback = - object : NetworkCallback() { - override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) { - super.onLinkPropertiesChanged(network, linkProperties) - currentDnsServers = ArrayList(linkProperties.dnsServers) +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.distinctUntilChanged +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.scan +import kotlinx.coroutines.flow.stateIn +import net.mullvad.talpid.util.NetworkEvent +import net.mullvad.talpid.util.defaultNetworkFlow +import net.mullvad.talpid.util.networkFlow + +class ConnectivityListener(val connectivityManager: ConnectivityManager) { + // Used by JNI + var senderAddress = 0L + set(value) { + if (value == 0L) { + destroySender(field) } + field = value } - private lateinit var connectivityManager: ConnectivityManager + private lateinit var _isConnected: StateFlow + // Used by JNI + val isConnected + get() = _isConnected.value + private lateinit var _currentDnsServers: StateFlow> // Used by JNI - var isConnected by - observable(false) { _, oldValue, newValue -> - if (newValue != oldValue) { - if (senderAddress != 0L) { - notifyConnectivityChange(newValue, senderAddress) + val currentDnsServers + get() = ArrayList(_currentDnsServers.value) + + fun register(scope: CoroutineScope) { + _currentDnsServers = + dnsServerChanges().stateIn(scope, SharingStarted.Eagerly, currentDnsServers()) + + _isConnected = + hasInternetCapability() + .onEach { + if (senderAddress != 0L) { + notifyConnectivityChange(it, senderAddress) + } } - } - } + .stateIn(scope, SharingStarted.Eagerly, false) + } - var currentDnsServers: ArrayList = ArrayList() - private set(value) { - field = ArrayList(value.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER }) - Logger.d("New currentDnsServers: $field") - } + fun unregister() { + senderAddress = 0L + } - var senderAddress = 0L + private fun dnsServerChanges(): Flow> = + connectivityManager + .defaultNetworkFlow() + .filterIsInstance() + .map { it.linkProperties.dnsServersWithoutFallback() } + + private fun currentDnsServers(): List = + connectivityManager + .getLinkProperties(connectivityManager.activeNetwork) + ?.dnsServersWithoutFallback() ?: emptyList() - fun register(context: Context) { + private fun LinkProperties.dnsServersWithoutFallback(): List = + dnsServers.filter { it.hostAddress != TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER } + + private fun hasInternetCapability(): Flow { val request = NetworkRequest.Builder() .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET) .addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN) .build() - connectivityManager = - context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager - - connectivityManager.registerNetworkCallback(request, callback) - currentDnsServers = - connectivityManager.getLinkProperties(connectivityManager.activeNetwork)?.dnsServers?.let { ArrayList(it) } - ?: ArrayList() - connectivityManager.registerDefaultNetworkCallback(defaultNetworkCallback) - } - - fun unregister() { - connectivityManager.unregisterNetworkCallback(callback) - connectivityManager.unregisterNetworkCallback(defaultNetworkCallback) - - if (senderAddress != 0L) { - var oldSender = senderAddress - senderAddress = 0L - destroySender(oldSender) - } + return connectivityManager + .networkFlow(request) + .scan(setOf()) { networks, event -> + when (event) { + is NetworkEvent.Available -> networks + event.network + is NetworkEvent.Lost -> networks - event.network + else -> networks + } + } + .map { it.isNotEmpty() } + .distinctUntilChanged() } private external fun notifyConnectivityChange(isConnected: Boolean, senderAddress: Long) diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt index 61c0be2ccf65..dfd6699b1e33 100644 --- a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/TalpidVpnService.kt @@ -1,7 +1,10 @@ package net.mullvad.talpid +import android.net.ConnectivityManager import android.os.ParcelFileDescriptor import androidx.annotation.CallSuper +import androidx.core.content.getSystemService +import androidx.lifecycle.lifecycleScope import co.touchlab.kermit.Logger import java.net.Inet4Address import java.net.Inet6Address @@ -29,12 +32,13 @@ open class TalpidVpnService : LifecycleVpnService() { private var currentTunConfig: TunConfig? = null // Used by JNI - val connectivityListener = ConnectivityListener() + lateinit var connectivityListener: ConnectivityListener @CallSuper override fun onCreate() { super.onCreate() - connectivityListener.register(this) + connectivityListener = ConnectivityListener(getSystemService()!!) + connectivityListener.register(lifecycleScope) } @CallSuper diff --git a/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt new file mode 100644 index 000000000000..a98edd8b5756 --- /dev/null +++ b/android/lib/talpid/src/main/kotlin/net/mullvad/talpid/util/ConnectivityManagerUtil.kt @@ -0,0 +1,132 @@ +package net.mullvad.talpid.util + +import android.net.ConnectivityManager +import android.net.ConnectivityManager.NetworkCallback +import android.net.LinkProperties +import android.net.Network +import android.net.NetworkCapabilities +import android.net.NetworkRequest +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.channels.trySendBlocking +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow + +sealed interface NetworkEvent { + data class Available(val network: Network) : NetworkEvent + + data object Unavailable : NetworkEvent + + data class LinkPropertiesChanged(val network: Network, val linkProperties: LinkProperties) : + NetworkEvent + + data class CapabilitiesChanged( + val network: Network, + val networkCapabilities: NetworkCapabilities, + ) : NetworkEvent + + data class BlockedStatusChanged(val network: Network, val blocked: Boolean) : NetworkEvent + + data class Losing(val network: Network, val maxMsToLive: Int) : NetworkEvent + + data class Lost(val network: Network) : NetworkEvent +} + +fun ConnectivityManager.defaultNetworkFlow(): Flow = + callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged( + network: Network, + linkProperties: LinkProperties, + ) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } + + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } + + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } + + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } + + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } + + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } + + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) + } + } + registerDefaultNetworkCallback(callback) + + awaitClose { unregisterNetworkCallback(callback) } + } + +fun ConnectivityManager.networkFlow(networkRequest: NetworkRequest): Flow = + callbackFlow { + val callback = + object : NetworkCallback() { + override fun onLinkPropertiesChanged( + network: Network, + linkProperties: LinkProperties, + ) { + super.onLinkPropertiesChanged(network, linkProperties) + trySendBlocking(NetworkEvent.LinkPropertiesChanged(network, linkProperties)) + } + + override fun onAvailable(network: Network) { + super.onAvailable(network) + trySendBlocking(NetworkEvent.Available(network)) + } + + override fun onCapabilitiesChanged( + network: Network, + networkCapabilities: NetworkCapabilities, + ) { + super.onCapabilitiesChanged(network, networkCapabilities) + trySendBlocking(NetworkEvent.CapabilitiesChanged(network, networkCapabilities)) + } + + override fun onBlockedStatusChanged(network: Network, blocked: Boolean) { + super.onBlockedStatusChanged(network, blocked) + trySendBlocking(NetworkEvent.BlockedStatusChanged(network, blocked)) + } + + override fun onLosing(network: Network, maxMsToLive: Int) { + super.onLosing(network, maxMsToLive) + trySendBlocking(NetworkEvent.Losing(network, maxMsToLive)) + } + + override fun onLost(network: Network) { + super.onLost(network) + trySendBlocking(NetworkEvent.Lost(network)) + } + + override fun onUnavailable() { + super.onUnavailable() + trySendBlocking(NetworkEvent.Unavailable) + } + } + registerNetworkCallback(networkRequest, callback) + + awaitClose { unregisterNetworkCallback(callback) } + } diff --git a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt index ebdcbec78019..55aa416e537b 100644 --- a/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt +++ b/android/service/src/main/kotlin/net/mullvad/mullvadvpn/service/MullvadVpnService.kt @@ -203,6 +203,7 @@ class MullvadVpnService : TalpidVpnService() { } override fun onDestroy() { + super.onDestroy() Logger.i("MullvadVpnService: onDestroy") // Shutting down the daemon gracefully managementService.stop() @@ -214,7 +215,6 @@ class MullvadVpnService : TalpidVpnService() { managementService.enterIdle() Logger.i("Shutdown complete") - super.onDestroy() } // If an intent is from the system it is because of the OS starting/stopping the VPN.