Commit 347d00fb authored by Mygod's avatar Mygod Committed by Max Lv

Add draft impl for local_dns_path

parent ad868f54
......@@ -31,13 +31,19 @@ import android.system.ErrnoException
import android.system.Os
import android.system.OsConstants
import com.github.shadowsocks.Core
import com.github.shadowsocks.net.LocalDnsServer
import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.*
import org.xbill.DNS.Message
import org.xbill.DNS.Opcode
import org.xbill.DNS.Type
import java.io.FileDescriptor
import java.io.IOException
import java.net.Inet4Address
import java.net.Inet6Address
import java.net.InetAddress
import java.util.concurrent.Executor
import java.util.concurrent.Executors
......@@ -94,6 +100,7 @@ sealed class DnsResolverCompat {
override fun bindSocket(network: Network, socket: FileDescriptor) = instance.bindSocket(network, socket)
override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host)
override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host)
override suspend fun resolveRaw(network: Network, query: ByteArray) = instance.resolveRaw(network, query)
}
@Throws(IOException::class)
......@@ -102,6 +109,7 @@ sealed class DnsResolverCompat {
Os.connect(fd, address, port)
abstract suspend fun resolve(network: Network, host: String): Array<InetAddress>
abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress>
abstract suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray
@SuppressLint("PrivateApi")
private open class DnsResolverCompat21 : DnsResolverCompat() {
......@@ -138,6 +146,28 @@ sealed class DnsResolverCompat {
GlobalScope.async(unboundedIO) { network.getAllByName(host) }.await()
override suspend fun resolveOnActiveNetwork(host: String) =
GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await()
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
val request = try {
Message(query)
} catch (e: IOException) {
throw UnsupportedOperationException(e) // unrecognized packet
}
when (val opcode = request.header.opcode) {
Opcode.QUERY -> { }
else -> throw UnsupportedOperationException("Unsupported opcode $opcode")
}
val question = request.question
val isIpv6 = when (val type = question?.type) {
Type.A -> false
Type.AAAA -> true
else -> throw UnsupportedOperationException("Unsupported query type $type")
}
val host = question.name.canonicalize().toString(true)
return LocalDnsServer.cookDnsResponse(request, resolve(network, host).asIterable().run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
})
}
}
@TargetApi(23)
......@@ -171,5 +201,17 @@ sealed class DnsResolverCompat {
override suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> {
return resolve(Core.connectivity.activeNetwork ?: return emptyArray(), host)
}
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
return suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
DnsResolver.getInstance().rawQuery(network, query, DnsResolver.FLAG_NO_RETRY, this,
signal, object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) = cont.resume(answer)
override fun onError(error: DnsResolver.DnsException) = cont.resumeWithException(IOException(error))
})
}
}
}
}
......@@ -30,27 +30,23 @@ import android.os.Build
import android.os.ParcelFileDescriptor
import android.system.ErrnoException
import android.system.OsConstants
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core
import com.github.shadowsocks.VpnRequestActivity
import com.github.shadowsocks.acl.Acl
import com.github.shadowsocks.core.R
import com.github.shadowsocks.net.ConcurrentLocalSocketListener
import com.github.shadowsocks.net.DefaultNetworkListener
import com.github.shadowsocks.net.HostsFile
import com.github.shadowsocks.net.Subnet
import com.github.shadowsocks.net.*
import com.github.shadowsocks.preference.DataStore
import com.github.shadowsocks.utils.Key
import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import java.io.File
import java.io.FileDescriptor
import java.io.IOException
import kotlinx.coroutines.*
import org.xbill.DNS.Message
import org.xbill.DNS.Rcode
import java.io.*
import java.net.URL
import java.util.*
import android.net.VpnService as BaseVpnService
class VpnService : BaseVpnService(), LocalDnsService.Interface {
......@@ -91,6 +87,45 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
}
}
private inner class LocalDnsWorker : ConcurrentLocalSocketListener("LocalDnsThread",
File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
while (true) {
val length = try {
input.readUnsignedShort()
} catch (_: EOFException) {
break
}
val query = ByteArray(length)
input.read(query)
try {
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
try {
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
} catch (e: Exception) {
printLog(e)
null
}
}?.let { response -> socket.outputStream.write(response) }
}
}
}
}
}
inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException {
override fun getLocalizedMessage() = getString(R.string.reboot_required)
}
......@@ -102,6 +137,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
private var conn: ParcelFileDescriptor? = null
private var worker: ProtectWorker? = null
private var localDns: LocalDnsWorker? = null
private var active = false
private var metered = false
private var underlyingNetwork: Network? = null
......@@ -126,6 +162,8 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
scope.launch { DefaultNetworkListener.stop(this) }
worker?.shutdown(scope)
worker = null
localDns?.shutdown(scope)
localDns = null
conn?.close()
conn = null
}
......@@ -147,6 +185,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
override suspend fun startProcesses(hosts: HostsFile) {
worker = ProtectWorker().apply { start() }
localDns = LocalDnsWorker().apply { start() }
super.startProcesses(hosts)
sendFd(startVpn())
}
......
......@@ -62,21 +62,20 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512
private fun prepareDnsResponse(request: Message) = Message(request.header.id).apply {
fun prepareDnsResponse(request: Message) = Message(request.header.id).apply {
header.setFlag(Flags.QR.toInt()) // this is a response
if (request.header.getFlag(Flags.RD.toInt())) header.setFlag(Flags.RD.toInt())
request.question?.also { addRecord(it, Section.QUESTION) }
}
private fun cookDnsResponse(request: Message, results: Iterable<InetAddress>) =
ByteBuffer.wrap(prepareDnsResponse(request).apply {
fun cookDnsResponse(request: Message, results: Iterable<InetAddress>) = prepareDnsResponse(request).apply {
header.setFlag(Flags.RA.toInt()) // recursion available
for (address in results) addRecord(when (address) {
is Inet4Address -> ARecord(question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(question.name, DClass.IN, TTL, address)
else -> error("Unsupported address $address")
}, Section.ANSWER)
}.toWire())
}.toWire()
}
private val monitor = ChannelMonitor()
......@@ -127,9 +126,9 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
val hostsResults = hosts.resolve(host)
if (hostsResults.isNotEmpty()) {
remote.cancel()
return@supervisorScope cookDnsResponse(request, hostsResults.run {
return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, hostsResults.run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
})
}))
}
val acl = acl?.await() ?: return@supervisorScope remote.await()
val useLocal = when (acl.shouldBypass(host)) {
......@@ -147,17 +146,17 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
}
if (isIpv6) {
val filtered = localResults.filterIsInstance<Inet6Address>()
if (useLocal) return@supervisorScope cookDnsResponse(request, filtered)
if (useLocal) return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, filtered))
if (filtered.any { acl.shouldBypassIpv6(it.address) }) {
remote.cancel()
cookDnsResponse(request, filtered)
ByteBuffer.wrap(cookDnsResponse(request, filtered))
} else remote.await()
} else {
val filtered = localResults.filterIsInstance<Inet4Address>()
if (useLocal) return@supervisorScope cookDnsResponse(request, filtered)
if (useLocal) return@supervisorScope ByteBuffer.wrap(cookDnsResponse(request, filtered))
if (filtered.any { acl.shouldBypassIpv4(it.address) }) {
remote.cancel()
cookDnsResponse(request, filtered)
ByteBuffer.wrap(cookDnsResponse(request, filtered))
} else remote.await()
}
} catch (e: Exception) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment