Skip to content

Commit

Permalink
Replace old waitForTunnelUp function
Browse files Browse the repository at this point in the history
After invoking VpnService.establish() we will get a tunnel file
descriptor that corresponds to the interface that was created. However,
this has no guarantee of the routing table beeing up to date, and we
might thus send traffic outside the tunnel. Previously this was done
through looking at the tunFd to see that traffic is sent to verify that
the routing table has changed. If no traffic is seen some traffic is
induced to a random IP address to ensure traffic can be seen. This new
implementation is slower but won't risk sending UDP traffic to a random
public address at the internet.
  • Loading branch information
Rawa committed Feb 6, 2025
1 parent 612aad8 commit 341c10b
Show file tree
Hide file tree
Showing 24 changed files with 748 additions and 412 deletions.
25 changes: 13 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package net.mullvad.talpid

import android.net.VpnService
import android.os.ParcelFileDescriptor
import arrow.core.right
import io.mockk.MockKAnnotations
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import io.mockk.mockkStatic
import io.mockk.spyk
import java.net.InetAddress
import net.mullvad.mullvadvpn.lib.common.test.assertLists
import net.mullvad.mullvadvpn.lib.common.util.prepareVpnSafe
import net.mullvad.mullvadvpn.lib.model.Prepared
import net.mullvad.talpid.model.CreateTunResult
import net.mullvad.talpid.model.InetNetwork
import net.mullvad.talpid.model.TunConfig
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertInstanceOf

class TalpidVpnServiceFallbackDnsTest {
lateinit var talpidVpnService: TalpidVpnService
var builderMockk = mockk<VpnService.Builder>()

@BeforeEach
fun setup() {
MockKAnnotations.init(this)
mockkStatic(VPN_SERVICE_EXTENSION)

talpidVpnService = spyk<TalpidVpnService>(recordPrivateCalls = true)
every { talpidVpnService.prepareVpnSafe() } returns Prepared.right()
builderMockk = mockk<VpnService.Builder>()

mockkConstructor(VpnService.Builder::class)
every { anyConstructed<VpnService.Builder>().setMtu(any()) } returns builderMockk
every { anyConstructed<VpnService.Builder>().setBlocking(any()) } returns builderMockk
every { anyConstructed<VpnService.Builder>().addAddress(any<InetAddress>(), any()) } returns
builderMockk
every { anyConstructed<VpnService.Builder>().addRoute(any<InetAddress>(), any()) } returns
builderMockk
every {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
} returns builderMockk
val parcelFileDescriptor: ParcelFileDescriptor = mockk()
every { anyConstructed<VpnService.Builder>().establish() } returns parcelFileDescriptor
every { parcelFileDescriptor.detachFd() } returns 1
}

@Test
fun `opening tun with no DnsServers should add fallback DNS server`() {
val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf())

val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.Success>(result)

// Fallback DNS server should be added if no DNS servers are provided
coVerify(exactly = 1) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `opening tun with all bad DnsServers should return InvalidDnsServers and add fallback`() {
val badDns1 = InetAddress.getByName("0.0.0.0")
val badDns2 = InetAddress.getByName("255.255.255.255")
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns1) } throws
IllegalArgumentException()
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns2) } throws
IllegalArgumentException()

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(badDns1, badDns2))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.InvalidDnsServers>(result)
assertLists(tunConfig.dnsServers, result.addresses)
// Fallback DNS server should be added if no valid DNS servers are provided
coVerify(exactly = 1) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `opening tun with 1 good and 1 bad DnsServers should return InvalidDnsServers`() {
val goodDnsServer = InetAddress.getByName("1.1.1.1")
val badDns = InetAddress.getByName("255.255.255.255")
every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns
builderMockk
every { anyConstructed<VpnService.Builder>().addDnsServer(badDns) } throws
IllegalArgumentException()

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer, badDns))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.InvalidDnsServers>(result)
assertLists(arrayListOf(badDns), result.addresses)

// Fallback DNS server should not be added since we have 1 good DNS server
coVerify(exactly = 0) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

@Test
fun `providing good dns servers should not add the fallback dns and return success`() {
val goodDnsServer = InetAddress.getByName("1.1.1.1")
every { anyConstructed<VpnService.Builder>().addDnsServer(goodDnsServer) } returns
builderMockk

val tunConfig = baseTunConfig.copy(dnsServers = arrayListOf(goodDnsServer))
val result = talpidVpnService.openTun(tunConfig)

assertInstanceOf<CreateTunResult.Success>(result)

// Fallback DNS server should not be added since we have good DNS servers.
coVerify(exactly = 0) {
anyConstructed<VpnService.Builder>()
.addDnsServer(TalpidVpnService.FALLBACK_DUMMY_DNS_SERVER)
}
}

companion object {
private const val VPN_SERVICE_EXTENSION =
"net.mullvad.mullvadvpn.lib.common.util.VpnServiceUtilsKt"

val baseTunConfig =
TunConfig(
addresses = arrayListOf(InetAddress.getByName("45.83.223.209")),
dnsServers = arrayListOf(),
routes =
arrayListOf(
InetNetwork(InetAddress.getByName("0.0.0.0"), 0),
InetNetwork(InetAddress.getByName("::"), 0),
),
mtu = 1280,
excludedPackages = arrayListOf(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@ package net.mullvad.mullvadvpn.lib.common.util

import android.content.Context
import android.content.Intent
import android.net.VpnService
import android.net.VpnService.prepare
import android.os.ParcelFileDescriptor
import arrow.core.Either
import arrow.core.flatten
import arrow.core.flatMap
import arrow.core.left
import arrow.core.raise.either
import arrow.core.raise.ensureNotNull
import arrow.core.right
import co.touchlab.kermit.Logger
import net.mullvad.mullvadvpn.lib.common.util.SdkUtils.getInstalledPackagesList
import net.mullvad.mullvadvpn.lib.model.PrepareError
import net.mullvad.mullvadvpn.lib.model.Prepared

/**
* Prepare to establish a VPN connection safely.
*
* Invoking VpnService.prepare() can result in 3 out comes:
* 1. IllegalStateException - There is a legacy VPN profile marked as always on
* 2. Intent
Expand All @@ -34,7 +40,7 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> =
else -> throw it
}
}
.map { intent ->
.flatMap { intent ->
if (intent == null) {
Prepared.right()
} else {
Expand All @@ -46,7 +52,6 @@ fun Context.prepareVpnSafe(): Either<PrepareError, Prepared> =
}
}
}
.flatten()

fun Context.getAlwaysOnVpnAppName(): String? {
return resolveAlwaysOnVpnPackageName()
Expand All @@ -59,3 +64,38 @@ fun Context.getAlwaysOnVpnAppName(): String? {
?.loadLabel(packageManager)
?.toString()
}

/**
* Establish a VPN connection safely.
*
* This function wraps the [VpnService.Builder.establish] function and catches any exceptions that
* may be thrown and type them to a more specific error.
*
* @return [ParcelFileDescriptor] if successful, [EstablishError] otherwise
*/
fun VpnService.Builder.establishSafe(): Either<EstablishError, ParcelFileDescriptor> = either {
val vpnInterfaceFd =
Either.catch { establish() }
.mapLeft {
when (it) {
is IllegalStateException -> EstablishError.ParameterNotApplied(it)
is IllegalArgumentException -> EstablishError.ParameterNotAccepted(it)
else -> EstablishError.UnknownError(it)
}
}
.bind()

ensureNotNull(vpnInterfaceFd) { EstablishError.NullVpnInterface }

vpnInterfaceFd
}

sealed interface EstablishError {
data class ParameterNotApplied(val exception: IllegalStateException) : EstablishError

data class ParameterNotAccepted(val exception: IllegalArgumentException) : EstablishError

data object NullVpnInterface : EstablishError

data class UnknownError(val error: Throwable) : EstablishError
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ import net.mullvad.mullvadvpn.lib.model.DnsState
import net.mullvad.mullvadvpn.lib.model.Endpoint
import net.mullvad.mullvadvpn.lib.model.ErrorState
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.AuthFailed
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.OtherAlwaysOnApp
import net.mullvad.mullvadvpn.lib.model.ErrorStateCause.TunnelParameterError
import net.mullvad.mullvadvpn.lib.model.FeatureIndicator
import net.mullvad.mullvadvpn.lib.model.GeoIpLocation
import net.mullvad.mullvadvpn.lib.model.GeoLocationId
Expand Down Expand Up @@ -125,7 +122,7 @@ private fun ManagementInterface.TunnelState.Error.toDomain(): TunnelState.Error
val otherAlwaysOnAppError =
errorState.let {
if (it.hasOtherAlwaysOnAppError()) {
OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName)
ErrorStateCause.OtherAlwaysOnApp(it.otherAlwaysOnAppError.appName)
} else {
null
}
Expand Down Expand Up @@ -238,7 +235,7 @@ internal fun ManagementInterface.ErrorState.toDomain(
cause =
when (cause!!) {
ManagementInterface.ErrorState.Cause.AUTH_FAILED ->
AuthFailed(authFailedError.toDomain())
ErrorStateCause.AuthFailed(authFailedError.toDomain())
ManagementInterface.ErrorState.Cause.IPV6_UNAVAILABLE ->
ErrorStateCause.Ipv6Unavailable
ManagementInterface.ErrorState.Cause.SET_FIREWALL_POLICY_ERROR ->
Expand All @@ -247,15 +244,14 @@ internal fun ManagementInterface.ErrorState.toDomain(
ManagementInterface.ErrorState.Cause.START_TUNNEL_ERROR ->
ErrorStateCause.StartTunnelError
ManagementInterface.ErrorState.Cause.TUNNEL_PARAMETER_ERROR ->
TunnelParameterError(parameterError.toDomain())
ErrorStateCause.TunnelParameterError(parameterError.toDomain())
ManagementInterface.ErrorState.Cause.IS_OFFLINE -> ErrorStateCause.IsOffline
ManagementInterface.ErrorState.Cause.SPLIT_TUNNEL_ERROR ->
ErrorStateCause.StartTunnelError
ManagementInterface.ErrorState.Cause.UNRECOGNIZED,
ManagementInterface.ErrorState.Cause.NEED_FULL_DISK_PERMISSIONS,
ManagementInterface.ErrorState.Cause.CREATE_TUNNEL_DEVICE ->
throw IllegalArgumentException("Unrecognized error state cause")

ManagementInterface.ErrorState.Cause.NOT_PREPARED -> ErrorStateCause.NotPrepared
ManagementInterface.ErrorState.Cause.OTHER_ALWAYS_ON_APP -> otherAlwaysOnApp!!
ManagementInterface.ErrorState.Cause.OTHER_LEGACY_ALWAYS_ON_VPN ->
Expand Down
Loading

0 comments on commit 341c10b

Please sign in to comment.