Commit 67236f31 authored by Mygod's avatar Mygod

Fix #2405

parent f9eadf26
...@@ -26,14 +26,15 @@ import android.net.DnsResolver ...@@ -26,14 +26,15 @@ import android.net.DnsResolver
import android.net.Network import android.net.Network
import android.os.Build import android.os.Build
import android.os.CancellationSignal import android.os.CancellationSignal
import android.os.Looper
import android.system.ErrnoException import android.system.ErrnoException
import android.system.Os import android.system.Os
import android.system.OsConstants import android.system.OsConstants
import com.github.shadowsocks.Core import com.github.shadowsocks.Core
import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.parseNumericAddress import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import com.github.shadowsocks.utils.use
import kotlinx.coroutines.* import kotlinx.coroutines.*
import java.io.FileDescriptor import java.io.FileDescriptor
import java.io.IOException import java.io.IOException
...@@ -61,12 +62,15 @@ sealed class DnsResolverCompat { ...@@ -61,12 +62,15 @@ sealed class DnsResolverCompat {
*/ */
private val address4 = "8.8.8.8".parseNumericAddress()!! private val address4 = "8.8.8.8".parseNumericAddress()!!
private val address6 = "2000::".parseNumericAddress()!! private val address6 = "2000::".parseNumericAddress()!!
fun haveIpv4(network: Network) = checkConnectivity(network, OsConstants.AF_INET, address4) suspend fun haveIpv4(network: Network) = checkConnectivity(network, OsConstants.AF_INET, address4)
fun haveIpv6(network: Network) = checkConnectivity(network, OsConstants.AF_INET6, address6) suspend fun haveIpv6(network: Network) = checkConnectivity(network, OsConstants.AF_INET6, address6)
private fun checkConnectivity(network: Network, domain: Int, addr: InetAddress) = try { private suspend fun checkConnectivity(network: Network, domain: Int, addr: InetAddress) = try {
Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP).use { socket -> val socket = Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP)
try {
instance.bindSocket(network, socket) instance.bindSocket(network, socket)
Os.connect(socket, addr, 0) instance.connectUdp(socket, addr)
} finally {
socket.closeQuietly()
} }
true true
} catch (_: IOException) { } catch (_: IOException) {
...@@ -93,6 +97,8 @@ sealed class DnsResolverCompat { ...@@ -93,6 +97,8 @@ sealed class DnsResolverCompat {
@Throws(IOException::class) @Throws(IOException::class)
abstract fun bindSocket(network: Network, socket: FileDescriptor) abstract fun bindSocket(network: Network, socket: FileDescriptor)
internal open suspend fun connectUdp(fd: FileDescriptor, address: InetAddress, port: Int = 0) =
Os.connect(fd, address, port)
abstract suspend fun resolve(network: Network, host: String): Array<InetAddress> abstract suspend fun resolve(network: Network, host: String): Array<InetAddress>
abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress>
...@@ -110,6 +116,12 @@ sealed class DnsResolverCompat { ...@@ -110,6 +116,12 @@ sealed class DnsResolverCompat {
throw IOException(message, ErrnoException(message, -err)) throw IOException(message, ErrnoException(message, -err))
} }
override suspend fun connectUdp(fd: FileDescriptor, address: InetAddress, port: Int) {
if (Looper.getMainLooper().thread == Thread.currentThread()) withContext(Dispatchers.IO) { // #2405
super.connectUdp(fd, address, port)
} else super.connectUdp(fd, address, port)
}
/** /**
* This dispatcher is used for noncancellable possibly-forever-blocking operations in network IO. * This dispatcher is used for noncancellable possibly-forever-blocking operations in network IO.
* *
......
...@@ -39,9 +39,9 @@ import com.github.shadowsocks.net.HostsFile ...@@ -39,9 +39,9 @@ import com.github.shadowsocks.net.HostsFile
import com.github.shadowsocks.net.Subnet import com.github.shadowsocks.net.Subnet
import com.github.shadowsocks.preference.DataStore import com.github.shadowsocks.preference.DataStore
import com.github.shadowsocks.utils.Key import com.github.shadowsocks.utils.Key
import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import com.github.shadowsocks.utils.use
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
...@@ -65,7 +65,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -65,7 +65,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
File(Core.deviceStorage.noBackupFilesDir, "protect_path")) { File(Core.deviceStorage.noBackupFilesDir, "protect_path")) {
override fun acceptInternal(socket: LocalSocket) { override fun acceptInternal(socket: LocalSocket) {
socket.inputStream.read() socket.inputStream.read()
socket.ancillaryFileDescriptors!!.single()!!.use { fd -> val fd = socket.ancillaryFileDescriptors!!.single()!!
try {
socket.outputStream.write(if (underlyingNetwork.let { network -> socket.outputStream.write(if (underlyingNetwork.let { network ->
if (network != null) try { if (network != null) try {
DnsResolverCompat.bindSocket(network, fd) DnsResolverCompat.bindSocket(network, fd)
...@@ -80,6 +81,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -80,6 +81,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
} }
protect(fd.int) protect(fd.int)
}) 0 else 1) }) 0 else 1)
} finally {
fd.closeQuietly()
} }
} }
} }
......
...@@ -66,13 +66,9 @@ val Throwable.readableMessage get() = localizedMessage ?: javaClass.name ...@@ -66,13 +66,9 @@ val Throwable.readableMessage get() = localizedMessage ?: javaClass.name
private val getInt = FileDescriptor::class.java.getDeclaredMethod("getInt$") private val getInt = FileDescriptor::class.java.getDeclaredMethod("getInt$")
val FileDescriptor.int get() = getInt.invoke(this) as Int val FileDescriptor.int get() = getInt.invoke(this) as Int
fun <T> FileDescriptor.use(block: (FileDescriptor) -> T) = try { fun FileDescriptor.closeQuietly() = try {
block(this) Os.close(this)
} finally { } catch (_: ErrnoException) { }
try {
Os.close(this)
} catch (_: ErrnoException) { }
}
private val parseNumericAddress by lazy @SuppressLint("DiscouragedPrivateApi") { private val parseNumericAddress by lazy @SuppressLint("DiscouragedPrivateApi") {
InetAddress::class.java.getDeclaredMethod("parseNumericAddress", String::class.java).apply { InetAddress::class.java.getDeclaredMethod("parseNumericAddress", String::class.java).apply {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment