Skip to content

Commit

Permalink
Initialize mullvad daemon directly in vpn service on create
Browse files Browse the repository at this point in the history
As a consequence move the address resolve to the daemon
  • Loading branch information
Pururun committed Aug 30, 2024
1 parent 26b80bf commit 9b68474
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 197 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package net.mullvad.mullvadvpn.lib.common.constant

const val GRPC_SOCKET_FILE_NAMED_ARGUMENT = "RPC_SOCKET"
const val FILES_DIR_NAMED_ARGUMENT = "FILES_DIR"
const val CACHE_DIR_NAMED_ARGUMENT = "CACHE_DIR"

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import android.os.Build

private const val OVERRIDE_API_EXTRA_NAME = "override_api"

fun Intent.putApiEndpointConfigurationExtra(apiEndpointConfiguration: ApiEndpointConfiguration) {
putExtra(OVERRIDE_API_EXTRA_NAME, apiEndpointConfiguration)
fun Intent.putApiEndpointConfigurationExtra(apiEndpointOverrideConfiguration: ApiEndpointOverride) {
putExtra(OVERRIDE_API_EXTRA_NAME, apiEndpointOverrideConfiguration)
}

fun Intent.getApiEndpointConfigurationExtras(): ApiEndpointConfiguration? {
fun Intent.getApiEndpointConfigurationExtras(): ApiEndpointOverride? {
return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
getParcelableExtra(OVERRIDE_API_EXTRA_NAME, ApiEndpointConfiguration::class.java)
getParcelableExtra(OVERRIDE_API_EXTRA_NAME, ApiEndpointOverride::class.java)
} else {
getParcelableExtra(OVERRIDE_API_EXTRA_NAME)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package net.mullvad.mullvadvpn.lib.endpoint

import android.os.Parcelable
import kotlinx.parcelize.Parcelize

@Parcelize
data class ApiEndpointOverride(
val hostname: String,
val port: Int = CUSTOM_ENDPOINT_HTTPS_PORT,
val disableAddressCache: Boolean = true,
val disableTls: Boolean = false,
val forceDirectConnection: Boolean = true,
) : Parcelable {
companion object {
const val CUSTOM_ENDPOINT_HTTPS_PORT = 443
}
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package net.mullvad.mullvadvpn.service

import java.io.File
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride

data class DaemonConfig(
val rpcSocket: File,
val filesDir: File,
val cacheDir: File,
val apiEndpointOverride: ApiEndpointOverride?,
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package net.mullvad.mullvadvpn.service

import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpoint
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride

object MullvadDaemon {
init {
Expand All @@ -12,7 +12,7 @@ object MullvadDaemon {
rpcSocketPath: String,
filesDirectory: String,
cacheDirectory: String,
apiEndpoint: ApiEndpoint?,
apiEndpointOverride: ApiEndpointOverride?,
)

external fun shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,17 @@ import androidx.lifecycle.lifecycleScope
import arrow.atomic.AtomicInt
import co.touchlab.kermit.Logger
import java.io.File
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import net.mullvad.mullvadvpn.lib.common.constant.GRPC_SOCKET_FILE_NAMED_ARGUMENT
import net.mullvad.mullvadvpn.lib.common.constant.KEY_CONNECT_ACTION
import net.mullvad.mullvadvpn.lib.common.constant.KEY_DISCONNECT_ACTION
import net.mullvad.mullvadvpn.lib.daemon.grpc.ManagementService
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointConfiguration
import net.mullvad.mullvadvpn.lib.endpoint.getApiEndpointConfigurationExtras
import net.mullvad.mullvadvpn.lib.intent.IntentProvider
import net.mullvad.mullvadvpn.lib.model.TunnelState
import net.mullvad.mullvadvpn.lib.shared.ConnectionProxy
import net.mullvad.mullvadvpn.service.di.apiEndpointModule
import net.mullvad.mullvadvpn.service.di.vpnServiceModule
import net.mullvad.mullvadvpn.service.migration.MigrateSplitTunneling
import net.mullvad.mullvadvpn.service.notifications.ForegroundNotificationManager
Expand All @@ -34,20 +30,18 @@ import net.mullvad.mullvadvpn.service.notifications.NotificationManager
import net.mullvad.talpid.TalpidVpnService
import org.koin.android.ext.android.getKoin
import org.koin.core.context.loadKoinModules
import org.koin.core.qualifier.named

private const val RELAYS_FILE = "relays.json"

class MullvadVpnService : TalpidVpnService() {

private lateinit var keyguardManager: KeyguardManager

private lateinit var apiEndpointConfiguration: ApiEndpointConfiguration
private lateinit var managementService: ManagementService
private lateinit var migrateSplitTunneling: MigrateSplitTunneling
private lateinit var intentProvider: IntentProvider
private lateinit var connectionProxy: ConnectionProxy
private lateinit var rpcSocketFile: File
private lateinit var daemonConfig: DaemonConfig

private lateinit var foregroundNotificationHandler: ForegroundNotificationManager

Expand All @@ -59,7 +53,7 @@ class MullvadVpnService : TalpidVpnService() {
super.onCreate()
Logger.i("MullvadVpnService: onCreate")

loadKoinModules(listOf(vpnServiceModule, apiEndpointModule))
loadKoinModules(listOf(vpnServiceModule))
with(getKoin()) {
// Needed to create all the notification channels
get<NotificationChannelFactory>()
Expand All @@ -70,27 +64,32 @@ class MullvadVpnService : TalpidVpnService() {
ForegroundNotificationManager(this@MullvadVpnService, get())
get<NotificationManager>()

apiEndpointConfiguration = get()
daemonConfig = get()
migrateSplitTunneling = get()
intentProvider = get()
connectionProxy = get()
rpcSocketFile = get(named(GRPC_SOCKET_FILE_NAMED_ARGUMENT))
}

keyguardManager = getSystemService<KeyguardManager>()!!

// TODO We should avoid lifecycleScope.launch (current needed due to InetSocketAddress
// with intent from API)
lifecycleScope.launch(context = Dispatchers.IO) {
prepareFiles(this@MullvadVpnService)
migrateSplitTunneling.migrate()
prepareFiles(this@MullvadVpnService)
migrateSplitTunneling.migrate()

Logger.i("Start daemon")
startDaemon()
// If it is a debug build and we have an api override in the intent, use it
// This is for injecting hostname and port for our mock api tests
val intentApiOverride =
intentProvider.getLatestIntent()?.getApiEndpointConfigurationExtras()
val updatedConfig =
if (BuildConfig.DEBUG && intentApiOverride != null) {
daemonConfig.copy(apiEndpointOverride = intentApiOverride)
} else {
daemonConfig
}
Logger.i("Start daemon")
startDaemon(updatedConfig)

Logger.i("Start management service")
managementService.start()
}
Logger.i("Start management service")
managementService.start()
}

override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
Expand Down Expand Up @@ -147,24 +146,17 @@ class MullvadVpnService : TalpidVpnService() {
}
}

private fun startDaemon() {
val apiEndpointConfiguration =
if (BuildConfig.DEBUG) {
intentProvider.getLatestIntent()?.getApiEndpointConfigurationExtras()
?: apiEndpointConfiguration
} else {
apiEndpointConfiguration
}

MullvadDaemon.initialize(
vpnService = this@MullvadVpnService,
rpcSocketPath = rpcSocketFile.absolutePath,
filesDirectory = filesDir.absolutePath,
cacheDirectory = cacheDir.absolutePath,
apiEndpoint = apiEndpointConfiguration.apiEndpoint(),
)
Logger.i("MullvadVpnService: Daemon initialized")
}
private fun startDaemon(daemonConfig: DaemonConfig) =
with(daemonConfig) {
MullvadDaemon.initialize(
vpnService = this@MullvadVpnService,
rpcSocketPath = rpcSocket.absolutePath,
filesDirectory = filesDir.absolutePath,
cacheDirectory = cacheDir.absolutePath,
apiEndpointOverride = apiEndpointOverride,
)
Logger.i("MullvadVpnService: Daemon initialized")
}

private fun emptyBinder() =
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ package net.mullvad.mullvadvpn.service.di

import androidx.core.app.NotificationManagerCompat
import kotlinx.coroutines.MainScope
import net.mullvad.mullvadvpn.lib.common.constant.CACHE_DIR_NAMED_ARGUMENT
import net.mullvad.mullvadvpn.lib.common.constant.FILES_DIR_NAMED_ARGUMENT
import net.mullvad.mullvadvpn.lib.common.constant.GRPC_SOCKET_FILE_NAMED_ARGUMENT
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.lib.model.NotificationChannel
import net.mullvad.mullvadvpn.service.BuildConfig
import net.mullvad.mullvadvpn.service.DaemonConfig
import net.mullvad.mullvadvpn.service.migration.MigrateSplitTunneling
import net.mullvad.mullvadvpn.service.notifications.NotificationChannelFactory
import net.mullvad.mullvadvpn.service.notifications.NotificationManager
Expand All @@ -12,12 +18,15 @@ import net.mullvad.mullvadvpn.service.notifications.tunnelstate.TunnelStateNotif
import org.koin.android.ext.koin.androidContext
import org.koin.core.module.dsl.createdAtStart
import org.koin.core.module.dsl.withOptions
import org.koin.core.qualifier.named
import org.koin.dsl.bind
import org.koin.dsl.module

val vpnServiceModule = module {
single { NotificationManagerCompat.from(androidContext()) }
single { androidContext().resources }
single(named(FILES_DIR_NAMED_ARGUMENT)) { androidContext().filesDir }
single(named(CACHE_DIR_NAMED_ARGUMENT)) { androidContext().cacheDir }

single { NotificationChannel.TunnelUpdates } bind NotificationChannel::class
single { NotificationChannel.AccountUpdates } bind NotificationChannel::class
Expand Down Expand Up @@ -46,4 +55,18 @@ val vpnServiceModule = module {
}

single { MigrateSplitTunneling(androidContext()) }

single {
DaemonConfig(
rpcSocket = get(named(GRPC_SOCKET_FILE_NAMED_ARGUMENT)),
filesDir = get(named(FILES_DIR_NAMED_ARGUMENT)),
cacheDir = get(named(CACHE_DIR_NAMED_ARGUMENT)),
apiEndpointOverride =
if (BuildConfig.FLAVOR_infrastructure != "prod") {
ApiEndpointOverride(BuildConfig.API_ENDPOINT)
} else {
null
},
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import android.widget.Button
import androidx.test.uiautomator.By
import androidx.test.uiautomator.UiDevice
import androidx.test.uiautomator.Until
import net.mullvad.mullvadvpn.lib.endpoint.CustomApiEndpointConfiguration
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.lib.endpoint.putApiEndpointConfigurationExtra
import net.mullvad.mullvadvpn.test.common.constant.APP_LAUNCH_TIMEOUT
import net.mullvad.mullvadvpn.test.common.constant.CONNECTION_TIMEOUT
Expand All @@ -23,7 +23,7 @@ class AppInteractor(
private val targetContext: Context,
private val targetPackageName: String,
) {
fun launch(customApiEndpointConfiguration: CustomApiEndpointConfiguration? = null) {
fun launch(customApiEndpointConfiguration: ApiEndpointOverride? = null) {
device.pressHome()
// Wait for launcher
device.wait(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import androidx.test.uiautomator.UiDevice
import co.touchlab.kermit.Logger
import de.mannodermaus.junit5.extensions.GrantPermissionExtension
import java.net.InetAddress
import net.mullvad.mullvadvpn.lib.endpoint.CustomApiEndpointConfiguration
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.test.common.interactor.AppInteractor
import net.mullvad.mullvadvpn.test.common.rule.CaptureScreenshotOnFailedTestRule
import net.mullvad.mullvadvpn.test.mockapi.constant.LOG_TAG
Expand All @@ -33,7 +33,7 @@ abstract class MockApiTest {
lateinit var device: UiDevice
lateinit var targetContext: Context
lateinit var app: AppInteractor
lateinit var endpoint: CustomApiEndpointConfiguration
lateinit var endpoint: ApiEndpointOverride

@BeforeEach
open fun setup() {
Expand All @@ -54,8 +54,8 @@ abstract class MockApiTest {
mockWebServer.shutdown()
}

private fun createEndpoint(port: Int): CustomApiEndpointConfiguration {
return CustomApiEndpointConfiguration(
private fun createEndpoint(port: Int): ApiEndpointOverride {
return ApiEndpointOverride(
InetAddress.getLocalHost().hostName,
port,
disableAddressCache = true,
Expand Down
Loading

0 comments on commit 9b68474

Please sign in to comment.