Commit 847149c2 authored by Mygod's avatar Mygod

Fix DNS resolving timeouts due to race conditions

parent cabe12e8
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
<uses-permission android:name="android.permission.INTERNET"/> <uses-permission android:name="android.permission.INTERNET"/>
<uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED"/> <uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED"/>
<uses-permission android:name="android.permission.WAKE_LOCK"/> <uses-permission android:name="android.permission.WAKE_LOCK"/>
<!-- This permission is only used on Android 6.0 due to bug: https://stackoverflow.com/a/33509180/2245107 -->
<uses-permission-sdk-23 android:name="android.permission.WRITE_SETTINGS"/>
<uses-feature android:name="android.software.leanback" <uses-feature android:name="android.software.leanback"
android:required="false"/> android:required="false"/>
......
...@@ -42,6 +42,7 @@ import com.github.shadowsocks.utils.printLog ...@@ -42,6 +42,7 @@ import com.github.shadowsocks.utils.printLog
import com.google.firebase.analytics.FirebaseAnalytics import com.google.firebase.analytics.FirebaseAnalytics
import kotlinx.coroutines.* import kotlinx.coroutines.*
import java.io.File import java.io.File
import java.net.InetAddress
import java.net.UnknownHostException import java.net.UnknownHostException
import java.util.* import java.util.*
...@@ -274,6 +275,9 @@ object BaseService { ...@@ -274,6 +275,9 @@ object BaseService {
} }
} }
suspend fun preInit() { }
suspend fun resolver(host: String) = InetAddress.getByName(host)
fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
val data = data val data = data
if (data.state != STOPPED) return Service.START_NOT_STICKY if (data.state != STOPPED) return Service.START_NOT_STICKY
...@@ -306,8 +310,9 @@ object BaseService { ...@@ -306,8 +310,9 @@ object BaseService {
data.changeState(CONNECTING) data.changeState(CONNECTING)
data.connectingJob = GlobalScope.launch(Dispatchers.Main) { data.connectingJob = GlobalScope.launch(Dispatchers.Main) {
try { try {
proxy.init() preInit()
data.udpFallback?.init() proxy.init(this@Interface::resolver)
data.udpFallback?.init(this@Interface::resolver)
killProcesses() killProcesses()
data.processes = GuardedProcessPool { data.processes = GuardedProcessPool {
......
/*******************************************************************************
* *
* Copyright (C) 2019 by Max Lv <max.c.lv@gmail.com> *
* Copyright (C) 2019 by Mygod Studio <contact-shadowsocks-android@mygod.be> *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*******************************************************************************/
package com.github.shadowsocks.bg
import android.annotation.TargetApi
import android.net.ConnectivityManager
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import android.os.Build
import android.widget.Toast
import androidx.core.content.getSystemService
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core.app
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.actor
import kotlinx.coroutines.runBlocking
object DefaultNetworkListener : CoroutineScope {
override val coroutineContext get() = Dispatchers.Default
private sealed class NetworkMessage {
class Start(val key: Any, val listener: (Network?) -> Unit) : NetworkMessage()
class Get : NetworkMessage() {
val response = CompletableDeferred<Network>()
}
class Stop(val key: Any) : NetworkMessage()
class Put(val network: Network) : NetworkMessage()
class Update(val network: Network) : NetworkMessage()
class Lost(val network: Network) : NetworkMessage()
}
private val networkActor = actor<NetworkMessage> {
val listeners = mutableMapOf<Any, (Network?) -> Unit>()
var network: Network? = null
val pendingRequests = arrayListOf<NetworkMessage.Get>()
for (message in channel) when (message) {
is NetworkMessage.Start -> {
if (listeners.isEmpty()) registerDefaultNetworkListener()
listeners[message.key] = message.listener
if (network != null) message.listener(network)
}
is NetworkMessage.Get -> {
check(listeners.isNotEmpty()) { "Getting network without any listeners is not supported" }
if (network == null) pendingRequests += message else message.response.complete(network)
}
is NetworkMessage.Stop -> {
if (!listeners.isEmpty() && // was not empty
listeners.remove(message.key) != null && listeners.isEmpty()) unregisterDefaultNetworkListener()
}
is NetworkMessage.Put -> {
network = message.network
pendingRequests.forEach { it.response.complete(message.network) }
pendingRequests.clear()
listeners.values.forEach { it(network) }
}
is NetworkMessage.Update -> if (network == message.network) listeners.values.forEach { it(network) }
is NetworkMessage.Lost -> if (network == message.network) {
network = null
listeners.values.forEach { it(null) }
}
}
}
suspend fun start(key: Any, listener: (Network?) -> Unit) = networkActor.send(NetworkMessage.Start(key, listener))
suspend fun get() = NetworkMessage.Get().run {
networkActor.send(this)
response.await()
}
suspend fun stop(key: Any) = networkActor.send(NetworkMessage.Stop(key))
// NB: this runs in ConnectivityThread, and this behavior cannot be changed until API 26
private object Callback : ConnectivityManager.NetworkCallback() {
override fun onAvailable(network: Network) = runBlocking { networkActor.send(NetworkMessage.Put(network)) }
override fun onCapabilitiesChanged(network: Network, networkCapabilities: NetworkCapabilities?) {
// it's a good idea to refresh capabilities
runBlocking { networkActor.send(NetworkMessage.Update(network)) }
}
override fun onLost(network: Network) = runBlocking { networkActor.send(NetworkMessage.Lost(network)) }
}
private val connectivity = app.getSystemService<ConnectivityManager>()!!
private val defaultNetworkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
.build()
/**
* Unfortunately registerDefaultNetworkCallback is going to return VPN interface since Android P DP1:
* https://android.googlesource.com/platform/frameworks/base/+/dda156ab0c5d66ad82bdcf76cda07cbc0a9c8a2e
*
* This makes doing a requestNetwork with REQUEST necessary so that we don't get ALL possible networks that
* satisfies default network capabilities but only THE default network. Unfortunately, we need to have
* android.permission.CHANGE_NETWORK_STATE to be able to call requestNetwork.
*
* Source: https://android.googlesource.com/platform/frameworks/base/+/2df4c7d/services/core/java/com/android/server/ConnectivityService.java#887
*/
private fun registerDefaultNetworkListener() {
if (Build.VERSION.SDK_INT in 24..27) @TargetApi(24) {
connectivity.registerDefaultNetworkCallback(Callback)
} else try {
// we want REQUEST here instead of LISTEN
connectivity.requestNetwork(defaultNetworkRequest, Callback)
} catch (e: SecurityException) {
// known bug: https://stackoverflow.com/a/33509180/2245107
if (Build.VERSION.SDK_INT != 23) Crashlytics.logException(e)
Toast.makeText(app, e.localizedMessage, Toast.LENGTH_SHORT).show()
connectivity.registerNetworkCallback(defaultNetworkRequest, Callback)
}
}
private fun unregisterDefaultNetworkListener() = connectivity.unregisterNetworkCallback(Callback)
}
...@@ -116,7 +116,7 @@ class GuardedProcessPool(private val onFatal: (IOException) -> Unit) : Coroutine ...@@ -116,7 +116,7 @@ class GuardedProcessPool(private val onFatal: (IOException) -> Unit) : Coroutine
private val guards = ArrayList<Guard>() private val guards = ArrayList<Guard>()
@MainThread @MainThread
suspend fun start(cmd: List<String>, onRestartCallback: (suspend () -> Unit)? = null) { fun start(cmd: List<String>, onRestartCallback: (suspend () -> Unit)? = null) {
Crashlytics.log(Log.DEBUG, TAG, "start process: " + Commandline.toString(cmd)) Crashlytics.log(Log.DEBUG, TAG, "start process: " + Commandline.toString(cmd))
val guard = Guard(cmd) val guard = Guard(cmd)
guard.start() guard.start()
......
...@@ -52,7 +52,7 @@ class ProxyInstance(val profile: Profile, private val route: String = profile.ro ...@@ -52,7 +52,7 @@ class ProxyInstance(val profile: Profile, private val route: String = profile.ro
private val plugin = PluginConfiguration(profile.plugin ?: "").selectedOptions private val plugin = PluginConfiguration(profile.plugin ?: "").selectedOptions
val pluginPath by lazy { PluginManager.init(plugin) } val pluginPath by lazy { PluginManager.init(plugin) }
suspend fun init() { suspend fun init(resolver: suspend (String) -> InetAddress) {
if (profile.host == "198.199.101.152") { if (profile.host == "198.199.101.152") {
val mdg = MessageDigest.getInstance("SHA-1") val mdg = MessageDigest.getInstance("SHA-1")
mdg.update(Core.packageInfo.signaturesCompat.first().toByteArray()) mdg.update(Core.packageInfo.signaturesCompat.first().toByteArray())
...@@ -84,7 +84,7 @@ class ProxyInstance(val profile: Profile, private val route: String = profile.ro ...@@ -84,7 +84,7 @@ class ProxyInstance(val profile: Profile, private val route: String = profile.ro
// it's hard to resolve DNS on a specific interface so we'll do it here // it's hard to resolve DNS on a specific interface so we'll do it here
if (profile.host.parseNumericAddress() == null) profile.host = withTimeout(10_000) { if (profile.host.parseNumericAddress() == null) profile.host = withTimeout(10_000) {
withContext(Dispatchers.IO) { InetAddress.getByName(profile.host).hostAddress } withContext(Dispatchers.IO) { resolver(profile.host).hostAddress }
} ?: throw UnknownHostException() } ?: throw UnknownHostException()
} }
......
...@@ -20,16 +20,16 @@ ...@@ -20,16 +20,16 @@
package com.github.shadowsocks.bg package com.github.shadowsocks.bg
import android.annotation.TargetApi
import android.app.Service import android.app.Service
import android.content.Intent import android.content.Intent
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.* import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.net.Network
import android.os.Build import android.os.Build
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.system.ErrnoException import android.system.ErrnoException
import android.system.Os import android.system.Os
import androidx.core.content.getSystemService
import com.github.shadowsocks.Core import com.github.shadowsocks.Core
import com.github.shadowsocks.VpnRequestActivity import com.github.shadowsocks.VpnRequestActivity
import com.github.shadowsocks.acl.Acl import com.github.shadowsocks.acl.Acl
...@@ -59,20 +59,6 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -59,20 +59,6 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
*/ */
private val getInt: Method = FileDescriptor::class.java.getDeclaredMethod("getInt$") private val getInt: Method = FileDescriptor::class.java.getDeclaredMethod("getInt$")
/**
* Unfortunately registerDefaultNetworkCallback is going to return VPN interface since Android P DP1:
* https://android.googlesource.com/platform/frameworks/base/+/dda156ab0c5d66ad82bdcf76cda07cbc0a9c8a2e
*
* This makes doing a requestNetwork with REQUEST necessary so that we don't get ALL possible networks that
* satisfies default network capabilities but only THE default network. Unfortunately we need to have
* android.permission.CHANGE_NETWORK_STATE to be able to call requestNetwork.
*
* Source: https://android.googlesource.com/platform/frameworks/base/+/2df4c7d/services/core/java/com/android/server/ConnectivityService.java#887
*/
private val defaultNetworkRequest = NetworkRequest.Builder()
.addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
.addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
.build()
} }
class CloseableFd(val fd: FileDescriptor) : Closeable { class CloseableFd(val fd: FileDescriptor) : Closeable {
...@@ -123,28 +109,13 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -123,28 +109,13 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
private var conn: ParcelFileDescriptor? = null private var conn: ParcelFileDescriptor? = null
private var worker: ProtectWorker? = null private var worker: ProtectWorker? = null
private var active = false
private var underlyingNetwork: Network? = null private var underlyingNetwork: Network? = null
@TargetApi(24)
set(value) { set(value) {
setUnderlyingNetworks(if (value == null) null else arrayOf(value))
field = value field = value
if (active && Build.VERSION.SDK_INT >= 22) setUnderlyingNetworks(underlyingNetworks)
} }
private val underlyingNetworks = underlyingNetwork?.let { arrayOf(it) }
private val connectivity by lazy { getSystemService<ConnectivityManager>()!! }
@TargetApi(24)
private val defaultNetworkCallback = object : ConnectivityManager.NetworkCallback() {
override fun onAvailable(network: Network) {
underlyingNetwork = network
}
override fun onCapabilitiesChanged(network: Network, networkCapabilities: NetworkCapabilities?) {
// it's a good idea to refresh capabilities
underlyingNetwork = network
}
override fun onLost(network: Network) {
underlyingNetwork = null
}
}
private var listeningForDefaultNetwork = false
override fun onBind(intent: Intent) = when (intent.action) { override fun onBind(intent: Intent) = when (intent.action) {
SERVICE_INTERFACE -> super<BaseVpnService>.onBind(intent) SERVICE_INTERFACE -> super<BaseVpnService>.onBind(intent)
...@@ -154,10 +125,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -154,10 +125,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
override fun onRevoke() = stopRunner() override fun onRevoke() = stopRunner()
override suspend fun killProcesses() { override suspend fun killProcesses() {
if (listeningForDefaultNetwork) { active = false
connectivity.unregisterNetworkCallback(defaultNetworkCallback) DefaultNetworkListener.stop(this)
listeningForDefaultNetwork = false
}
worker?.shutdown() worker?.shutdown()
worker = null worker = null
super.killProcesses() super.killProcesses()
...@@ -175,6 +144,9 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -175,6 +144,9 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
return Service.START_NOT_STICKY return Service.START_NOT_STICKY
} }
override suspend fun preInit() = DefaultNetworkListener.start(this) { underlyingNetwork = it }
override suspend fun resolver(host: String) = DefaultNetworkListener.get().getByName(host)
override suspend fun startProcesses() { override suspend fun startProcesses() {
worker = ProtectWorker().apply { start() } worker = ProtectWorker().apply { start() }
...@@ -230,16 +202,13 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -230,16 +202,13 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
} }
} }
active = true // possible race condition here?
if (Build.VERSION.SDK_INT >= 22) builder.setUnderlyingNetworks(underlyingNetworks)
val conn = builder.establish() ?: throw NullConnectionException() val conn = builder.establish() ?: throw NullConnectionException()
this.conn = conn this.conn = conn
val fd = conn.fd val fd = conn.fd
if (Build.VERSION.SDK_INT >= 24) {
// we want REQUEST here instead of LISTEN
connectivity.requestNetwork(defaultNetworkRequest, defaultNetworkCallback)
listeningForDefaultNetwork = true
}
val cmd = arrayListOf(File(applicationInfo.nativeLibraryDir, Executable.TUN2SOCKS).absolutePath, val cmd = arrayListOf(File(applicationInfo.nativeLibraryDir, Executable.TUN2SOCKS).absolutePath,
"--netif-ipaddr", PRIVATE_VLAN.format(Locale.ENGLISH, "2"), "--netif-ipaddr", PRIVATE_VLAN.format(Locale.ENGLISH, "2"),
"--netif-netmask", "255.255.255.0", "--netif-netmask", "255.255.255.0",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment