Commit 91632d10 authored by Mygod's avatar Mygod

Move local dns handler to BaseService

parent 445461f3
......@@ -46,7 +46,6 @@ import java.io.File
import java.io.IOException
import java.net.URL
import java.net.UnknownHostException
import java.util.*
/**
* This object uses WeakMap to simulate the effects of multi-inheritance.
......@@ -74,6 +73,7 @@ object BaseService {
var processes: GuardedProcessPool? = null
var proxy: ProxyInstance? = null
var udpFallback: ProxyInstance? = null
var localDns: LocalDnsWorker? = null
var notification: ServiceNotification? = null
val closeReceiver = broadcastReceiver { _, intent ->
......@@ -247,6 +247,7 @@ object BaseService {
File(Core.deviceStorage.noBackupFilesDir, "stat_udp"),
File(configRoot, CONFIG_FILE_UDP),
"-u")
data.localDns = LocalDnsWorker(this::rawResolver).apply { start() }
}
fun startRunner() {
......@@ -260,6 +261,8 @@ object BaseService {
close(scope)
data.processes = null
}
data.localDns?.shutdown(scope)
data.localDns = null
}
fun stopRunner(restart: Boolean = false, msg: String? = null) {
......@@ -309,6 +312,7 @@ object BaseService {
suspend fun preInit() { }
suspend fun getActiveNetwork() = if (Build.VERSION.SDK_INT >= 23) Core.connectivity.activeNetwork else null
suspend fun resolver(host: String) = DnsResolverCompat.resolveOnActiveNetwork(host)
suspend fun rawResolver(query: ByteArray) = DnsResolverCompat.resolveRawOnActiveNetwork(query)
suspend fun openConnection(url: URL) = url.openConnection()
fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
......
......@@ -101,6 +101,7 @@ sealed class DnsResolverCompat {
override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host)
override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host)
override suspend fun resolveRaw(network: Network, query: ByteArray) = instance.resolveRaw(network, query)
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) = instance.resolveRawOnActiveNetwork(query)
}
@Throws(IOException::class)
......@@ -110,6 +111,7 @@ sealed class DnsResolverCompat {
abstract suspend fun resolve(network: Network, host: String): Array<InetAddress>
abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress>
abstract suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray
abstract suspend fun resolveRawOnActiveNetwork(query: ByteArray): ByteArray
@SuppressLint("PrivateApi")
private open class DnsResolverCompat21 : DnsResolverCompat() {
......@@ -147,7 +149,8 @@ sealed class DnsResolverCompat {
override suspend fun resolveOnActiveNetwork(host: String) =
GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await()
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
private suspend fun resolveRaw(query: ByteArray,
hostResolver: suspend (String) -> Array<InetAddress>): ByteArray {
val request = try {
Message(query)
} catch (e: IOException) {
......@@ -164,10 +167,14 @@ sealed class DnsResolverCompat {
else -> throw UnsupportedOperationException("Unsupported query type $type")
}
val host = question.name.canonicalize().toString(true)
return LocalDnsServer.cookDnsResponse(request, resolve(network, host).asIterable().run {
return LocalDnsServer.cookDnsResponse(request, hostResolver(host).asIterable().run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
})
}
override suspend fun resolveRaw(network: Network, query: ByteArray) =
resolveRaw(query) { resolve(network, it) }
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) =
resolveRaw(query, this::resolveOnActiveNetwork)
}
@TargetApi(23)
......@@ -184,6 +191,8 @@ sealed class DnsResolverCompat {
override fun bindSocket(network: Network, socket: FileDescriptor) = network.bindSocket(socket)
private val activeNetwork get() = Core.connectivity.activeNetwork ?: throw IOException("no network")
override suspend fun resolve(network: Network, host: String): Array<InetAddress> {
return suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
......@@ -197,10 +206,7 @@ sealed class DnsResolverCompat {
})
}
}
override suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> {
return resolve(Core.connectivity.activeNetwork ?: return emptyArray(), host)
}
override suspend fun resolveOnActiveNetwork(host: String) = resolve(activeNetwork, host)
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
return suspendCancellableCoroutine { cont ->
......@@ -213,5 +219,6 @@ sealed class DnsResolverCompat {
})
}
}
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) = resolveRaw(activeNetwork, query)
}
}
package com.github.shadowsocks.bg
import android.net.LocalSocket
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core
import com.github.shadowsocks.net.ConcurrentLocalSocketListener
import com.github.shadowsocks.net.LocalDnsServer
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.launch
import org.xbill.DNS.Message
import org.xbill.DNS.Rcode
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.File
import java.io.IOException
class LocalDnsWorker(private val resolver: suspend (ByteArray) -> ByteArray) : ConcurrentLocalSocketListener(
"LocalDnsThread", File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
val query = ByteArray(input.readUnsignedShort())
input.read(query)
try {
resolver(query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
try {
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
} catch (_: IOException) {
byteArrayOf() // return empty if cannot parse packet
}
}?.let { response ->
val output = DataOutputStream(socket.outputStream)
output.writeShort(response.size)
output.write(response)
}
}
}
}
}
......@@ -30,8 +30,6 @@ import android.os.Build
import android.os.ParcelFileDescriptor
import android.system.ErrnoException
import android.system.OsConstants
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core
import com.github.shadowsocks.VpnRequestActivity
import com.github.shadowsocks.acl.Acl
......@@ -43,8 +41,6 @@ import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.*
import org.xbill.DNS.Message
import org.xbill.DNS.Rcode
import java.io.*
import java.net.URL
import android.net.VpnService as BaseVpnService
......@@ -87,37 +83,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
}
}
private inner class LocalDnsWorker : ConcurrentLocalSocketListener("LocalDnsThread",
File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
val query = ByteArray(input.readUnsignedShort())
input.read(query)
try {
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
}?.let { response ->
val output = DataOutputStream(socket.outputStream)
output.writeShort(response.size)
output.write(response)
}
}
}
}
}
inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException {
override fun getLocalizedMessage() = getString(R.string.reboot_required)
}
......@@ -129,7 +94,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
private var conn: ParcelFileDescriptor? = null
private var worker: ProtectWorker? = null
private var localDns: LocalDnsWorker? = null
private var active = false
private var metered = false
private var underlyingNetwork: Network? = null
......@@ -154,8 +118,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
scope.launch { DefaultNetworkListener.stop(this) }
worker?.shutdown(scope)
worker = null
localDns?.shutdown(scope)
localDns = null
conn?.close()
conn = null
}
......@@ -173,11 +135,14 @@ class VpnService : BaseVpnService(), BaseService.Interface {
override suspend fun preInit() = DefaultNetworkListener.start(this) { underlyingNetwork = it }
override suspend fun getActiveNetwork() = DefaultNetworkListener.get()
override suspend fun resolver(host: String) = DnsResolverCompat.resolve(DefaultNetworkListener.get(), host)
override suspend fun rawResolver(query: ByteArray) =
// no need to listen for network here as this is only used for forwarding local DNS queries.
// retries should be attempted by client.
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
override suspend fun openConnection(url: URL) = DefaultNetworkListener.get().openConnection(url)
override suspend fun startProcesses(hosts: HostsFile) {
worker = ProtectWorker().apply { start() }
localDns = LocalDnsWorker().apply { start() }
super.startProcesses(hosts)
sendFd(startVpn())
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment