Commit 347d00fb authored by Mygod's avatar Mygod Committed by Max Lv

Add draft impl for local_dns_path

parent ad868f54
...@@ -31,13 +31,19 @@ import android.system.ErrnoException ...@@ -31,13 +31,19 @@ import android.system.ErrnoException
import android.system.Os import android.system.Os
import android.system.OsConstants import android.system.OsConstants
import com.github.shadowsocks.Core import com.github.shadowsocks.Core
import com.github.shadowsocks.net.LocalDnsServer
import com.github.shadowsocks.utils.closeQuietly import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.parseNumericAddress import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.* import kotlinx.coroutines.*
import org.xbill.DNS.Message
import org.xbill.DNS.Opcode
import org.xbill.DNS.Type
import java.io.FileDescriptor import java.io.FileDescriptor
import java.io.IOException import java.io.IOException
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.Executors import java.util.concurrent.Executors
...@@ -94,6 +100,7 @@ sealed class DnsResolverCompat { ...@@ -94,6 +100,7 @@ sealed class DnsResolverCompat {
override fun bindSocket(network: Network, socket: FileDescriptor) = instance.bindSocket(network, socket) override fun bindSocket(network: Network, socket: FileDescriptor) = instance.bindSocket(network, socket)
override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host) override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host)
override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host) override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host)
override suspend fun resolveRaw(network: Network, query: ByteArray) = instance.resolveRaw(network, query)
} }
@Throws(IOException::class) @Throws(IOException::class)
...@@ -102,6 +109,7 @@ sealed class DnsResolverCompat { ...@@ -102,6 +109,7 @@ sealed class DnsResolverCompat {
Os.connect(fd, address, port) Os.connect(fd, address, port)
abstract suspend fun resolve(network: Network, host: String): Array<InetAddress> abstract suspend fun resolve(network: Network, host: String): Array<InetAddress>
abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress>
abstract suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray
@SuppressLint("PrivateApi") @SuppressLint("PrivateApi")
private open class DnsResolverCompat21 : DnsResolverCompat() { private open class DnsResolverCompat21 : DnsResolverCompat() {
...@@ -138,6 +146,28 @@ sealed class DnsResolverCompat { ...@@ -138,6 +146,28 @@ sealed class DnsResolverCompat {
GlobalScope.async(unboundedIO) { network.getAllByName(host) }.await() GlobalScope.async(unboundedIO) { network.getAllByName(host) }.await()
override suspend fun resolveOnActiveNetwork(host: String) = override suspend fun resolveOnActiveNetwork(host: String) =
GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await() GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await()
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
val request = try {
Message(query)
} catch (e: IOException) {
throw UnsupportedOperationException(e) // unrecognized packet
}
when (val opcode = request.header.opcode) {
Opcode.QUERY -> { }
else -> throw UnsupportedOperationException("Unsupported opcode $opcode")
}
val question = request.question
val isIpv6 = when (val type = question?.type) {
Type.A -> false
Type.AAAA -> true
else -> throw UnsupportedOperationException("Unsupported query type $type")
}
val host = question.name.canonicalize().toString(true)
return LocalDnsServer.cookDnsResponse(request, resolve(network, host).asIterable().run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
})
}
} }
@TargetApi(23) @TargetApi(23)
...@@ -171,5 +201,17 @@ sealed class DnsResolverCompat { ...@@ -171,5 +201,17 @@ sealed class DnsResolverCompat {
override suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> { override suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> {
return resolve(Core.connectivity.activeNetwork ?: return emptyArray(), host) return resolve(Core.connectivity.activeNetwork ?: return emptyArray(), host)
} }
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
return suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
DnsResolver.getInstance().rawQuery(network, query, DnsResolver.FLAG_NO_RETRY, this,
signal, object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) = cont.resume(answer)
override fun onError(error: DnsResolver.DnsException) = cont.resumeWithException(IOException(error))
})
}
}
} }
} }
...@@ -30,27 +30,23 @@ import android.os.Build ...@@ -30,27 +30,23 @@ import android.os.Build
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.system.ErrnoException import android.system.ErrnoException
import android.system.OsConstants import android.system.OsConstants
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core import com.github.shadowsocks.Core
import com.github.shadowsocks.VpnRequestActivity import com.github.shadowsocks.VpnRequestActivity
import com.github.shadowsocks.acl.Acl import com.github.shadowsocks.acl.Acl
import com.github.shadowsocks.core.R import com.github.shadowsocks.core.R
import com.github.shadowsocks.net.ConcurrentLocalSocketListener import com.github.shadowsocks.net.*
import com.github.shadowsocks.net.DefaultNetworkListener
import com.github.shadowsocks.net.HostsFile
import com.github.shadowsocks.net.Subnet
import com.github.shadowsocks.preference.DataStore import com.github.shadowsocks.preference.DataStore
import com.github.shadowsocks.utils.Key import com.github.shadowsocks.utils.Key
import com.github.shadowsocks.utils.closeQuietly import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.*
import kotlinx.coroutines.delay import org.xbill.DNS.Message
import kotlinx.coroutines.launch import org.xbill.DNS.Rcode
import java.io.File import java.io.*
import java.io.FileDescriptor
import java.io.IOException
import java.net.URL import java.net.URL
import java.util.*
import android.net.VpnService as BaseVpnService import android.net.VpnService as BaseVpnService
class VpnService : BaseVpnService(), LocalDnsService.Interface { class VpnService : BaseVpnService(), LocalDnsService.Interface {
...@@ -91,6 +87,45 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -91,6 +87,45 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
} }
} }
private inner class LocalDnsWorker : ConcurrentLocalSocketListener("LocalDnsThread",
File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
while (true) {
val length = try {
input.readUnsignedShort()
} catch (_: EOFException) {
break
}
val query = ByteArray(length)
input.read(query)
try {
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
try {
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
} catch (e: Exception) {
printLog(e)
null
}
}?.let { response -> socket.outputStream.write(response) }
}
}
}
}
}
inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException { inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException {
override fun getLocalizedMessage() = getString(R.string.reboot_required) override fun getLocalizedMessage() = getString(R.string.reboot_required)
} }
...@@ -102,6 +137,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -102,6 +137,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
private var conn: ParcelFileDescriptor? = null private var conn: ParcelFileDescriptor? = null
private var worker: ProtectWorker? = null private var worker: ProtectWorker? = null
private var localDns: LocalDnsWorker? = null
private var active = false private var active = false
private var metered = false private var metered = false
private var underlyingNetwork: Network? = null private var underlyingNetwork: Network? = null
...@@ -126,6 +162,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -126,6 +162,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
scope.launch { DefaultNetworkListener.stop(this) } scope.launch { DefaultNetworkListener.stop(this) }
worker?.shutdown(scope) worker?.shutdown(scope)
worker = null worker = null
localDns?.shutdown(scope)
localDns = null
conn?.close() conn?.close()
conn = null conn = null
} }
...@@ -147,6 +185,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface { ...@@ -147,6 +185,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
override suspend fun startProcesses(hosts: HostsFile) { override suspend fun startProcesses(hosts: HostsFile) {
worker = ProtectWorker().apply { start() } worker = ProtectWorker().apply { start() }
localDns = LocalDnsWorker().apply { start() }
super.startProcesses(hosts) super.startProcesses(hosts)
sendFd(startVpn()) sendFd(startVpn())
} }
......
...@@ -62,21 +62,20 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -62,21 +62,20 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512 private const val UDP_PACKET_SIZE = 512
private fun prepareDnsResponse(request: Message) = Message(request.header.id).apply { fun prepareDnsResponse(request: Message) = Message(request.header.id).apply {
header.setFlag(Flags.QR.toInt()) // this is a response header.setFlag(Flags.QR.toInt()) // this is a response
if (request.header.getFlag(Flags.RD.toInt())) header.setFlag(Flags.RD.toInt()) if (request.header.getFlag(Flags.RD.toInt())) header.setFlag(Flags.RD.toInt())
request.question?.also { addRecord(it, Section.QUESTION) } request.question?.also { addRecord(it, Section.QUESTION) }
} }
private fun cookDnsResponse(request: Message, results: Iterable<InetAddress>) = fun cookDnsResponse(request: Message, results: Iterable<InetAddress>) = prepareDnsResponse(request).apply {
ByteBuffer.wrap(prepareDnsResponse(request).apply {
header.setFlag(Flags.RA.toInt()) // recursion available header.setFlag(Flags.RA.toInt()) // recursion available
for (address in results) addRecord(when (address) { for (address in results) addRecord(when (address) {
is Inet4Address -> ARecord(question.name, DClass.IN, TTL, address) is Inet4Address -> ARecord(question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(question.name, DClass.IN, TTL, address) is Inet6Address -> AAAARecord(question.name, DClass.IN, TTL, address)
else -> error("Unsupported address $address") else -> error("Unsupported address $address")
}, Section.ANSWER) }, Section.ANSWER)
}.toWire()) }.toWire()
} }
private val monitor = ChannelMonitor() private val monitor = ChannelMonitor()
...@@ -127,9 +126,9 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -127,9 +126,9 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
val hostsResults = hosts.resolve(host) val hostsResults = hosts.resolve(host)
if (hostsResults.isNotEmpty()) { if (hostsResults.isNotEmpty()) {
remote.cancel() remote.cancel()
return@supervisorScope cookDnsResponse(request, hostsResults.run { return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, hostsResults.run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>() if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
}) }))
} }
val acl = acl?.await() ?: return@supervisorScope remote.await() val acl = acl?.await() ?: return@supervisorScope remote.await()
val useLocal = when (acl.shouldBypass(host)) { val useLocal = when (acl.shouldBypass(host)) {
...@@ -147,17 +146,17 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -147,17 +146,17 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
} }
if (isIpv6) { if (isIpv6) {
val filtered = localResults.filterIsInstance<Inet6Address>() val filtered = localResults.filterIsInstance<Inet6Address>()
if (useLocal) return@supervisorScope cookDnsResponse(request, filtered) if (useLocal) return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, filtered))
if (filtered.any { acl.shouldBypassIpv6(it.address) }) { if (filtered.any { acl.shouldBypassIpv6(it.address) }) {
remote.cancel() remote.cancel()
cookDnsResponse(request, filtered) ByteBuffer.wrap(cookDnsResponse(request, filtered))
} else remote.await() } else remote.await()
} else { } else {
val filtered = localResults.filterIsInstance<Inet4Address>() val filtered = localResults.filterIsInstance<Inet4Address>()
if (useLocal) return@supervisorScope cookDnsResponse(request, filtered) if (useLocal) return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, filtered))
if (filtered.any { acl.shouldBypassIpv4(it.address) }) { if (filtered.any { acl.shouldBypassIpv4(it.address) }) {
remote.cancel() remote.cancel()
cookDnsResponse(request, filtered) ByteBuffer.wrap(cookDnsResponse(request, filtered))
} else remote.await() } else remote.await()
} }
} catch (e: Exception) { } catch (e: Exception) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment