Commit 91632d10 authored by Mygod's avatar Mygod

Move local dns handler to BaseService

parent 445461f3
...@@ -46,7 +46,6 @@ import java.io.File ...@@ -46,7 +46,6 @@ import java.io.File
import java.io.IOException import java.io.IOException
import java.net.URL import java.net.URL
import java.net.UnknownHostException import java.net.UnknownHostException
import java.util.*
/** /**
* This object uses WeakMap to simulate the effects of multi-inheritance. * This object uses WeakMap to simulate the effects of multi-inheritance.
...@@ -74,6 +73,7 @@ object BaseService { ...@@ -74,6 +73,7 @@ object BaseService {
var processes: GuardedProcessPool? = null var processes: GuardedProcessPool? = null
var proxy: ProxyInstance? = null var proxy: ProxyInstance? = null
var udpFallback: ProxyInstance? = null var udpFallback: ProxyInstance? = null
var localDns: LocalDnsWorker? = null
var notification: ServiceNotification? = null var notification: ServiceNotification? = null
val closeReceiver = broadcastReceiver { _, intent -> val closeReceiver = broadcastReceiver { _, intent ->
...@@ -247,6 +247,7 @@ object BaseService { ...@@ -247,6 +247,7 @@ object BaseService {
File(Core.deviceStorage.noBackupFilesDir, "stat_udp"), File(Core.deviceStorage.noBackupFilesDir, "stat_udp"),
File(configRoot, CONFIG_FILE_UDP), File(configRoot, CONFIG_FILE_UDP),
"-u") "-u")
data.localDns = LocalDnsWorker(this::rawResolver).apply { start() }
} }
fun startRunner() { fun startRunner() {
...@@ -260,6 +261,8 @@ object BaseService { ...@@ -260,6 +261,8 @@ object BaseService {
close(scope) close(scope)
data.processes = null data.processes = null
} }
data.localDns?.shutdown(scope)
data.localDns = null
} }
fun stopRunner(restart: Boolean = false, msg: String? = null) { fun stopRunner(restart: Boolean = false, msg: String? = null) {
...@@ -309,6 +312,7 @@ object BaseService { ...@@ -309,6 +312,7 @@ object BaseService {
suspend fun preInit() { } suspend fun preInit() { }
suspend fun getActiveNetwork() = if (Build.VERSION.SDK_INT >= 23) Core.connectivity.activeNetwork else null suspend fun getActiveNetwork() = if (Build.VERSION.SDK_INT >= 23) Core.connectivity.activeNetwork else null
suspend fun resolver(host: String) = DnsResolverCompat.resolveOnActiveNetwork(host) suspend fun resolver(host: String) = DnsResolverCompat.resolveOnActiveNetwork(host)
suspend fun rawResolver(query: ByteArray) = DnsResolverCompat.resolveRawOnActiveNetwork(query)
suspend fun openConnection(url: URL) = url.openConnection() suspend fun openConnection(url: URL) = url.openConnection()
fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
......
...@@ -101,6 +101,7 @@ sealed class DnsResolverCompat { ...@@ -101,6 +101,7 @@ sealed class DnsResolverCompat {
override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host) override suspend fun resolve(network: Network, host: String) = instance.resolve(network, host)
override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host) override suspend fun resolveOnActiveNetwork(host: String) = instance.resolveOnActiveNetwork(host)
override suspend fun resolveRaw(network: Network, query: ByteArray) = instance.resolveRaw(network, query) override suspend fun resolveRaw(network: Network, query: ByteArray) = instance.resolveRaw(network, query)
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) = instance.resolveRawOnActiveNetwork(query)
} }
@Throws(IOException::class) @Throws(IOException::class)
...@@ -110,6 +111,7 @@ sealed class DnsResolverCompat { ...@@ -110,6 +111,7 @@ sealed class DnsResolverCompat {
abstract suspend fun resolve(network: Network, host: String): Array<InetAddress> abstract suspend fun resolve(network: Network, host: String): Array<InetAddress>
abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> abstract suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress>
abstract suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray abstract suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray
abstract suspend fun resolveRawOnActiveNetwork(query: ByteArray): ByteArray
@SuppressLint("PrivateApi") @SuppressLint("PrivateApi")
private open class DnsResolverCompat21 : DnsResolverCompat() { private open class DnsResolverCompat21 : DnsResolverCompat() {
...@@ -147,7 +149,8 @@ sealed class DnsResolverCompat { ...@@ -147,7 +149,8 @@ sealed class DnsResolverCompat {
override suspend fun resolveOnActiveNetwork(host: String) = override suspend fun resolveOnActiveNetwork(host: String) =
GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await() GlobalScope.async(unboundedIO) { InetAddress.getAllByName(host) }.await()
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray { private suspend fun resolveRaw(query: ByteArray,
hostResolver: suspend (String) -> Array<InetAddress>): ByteArray {
val request = try { val request = try {
Message(query) Message(query)
} catch (e: IOException) { } catch (e: IOException) {
...@@ -164,10 +167,14 @@ sealed class DnsResolverCompat { ...@@ -164,10 +167,14 @@ sealed class DnsResolverCompat {
else -> throw UnsupportedOperationException("Unsupported query type $type") else -> throw UnsupportedOperationException("Unsupported query type $type")
} }
val host = question.name.canonicalize().toString(true) val host = question.name.canonicalize().toString(true)
return LocalDnsServer.cookDnsResponse(request, resolve(network, host).asIterable().run { return LocalDnsServer.cookDnsResponse(request, hostResolver(host).asIterable().run {
if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>() if (isIpv6) filterIsInstance<Inet6Address>() else filterIsInstance<Inet4Address>()
}) })
} }
override suspend fun resolveRaw(network: Network, query: ByteArray) =
resolveRaw(query) { resolve(network, it) }
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) =
resolveRaw(query, this::resolveOnActiveNetwork)
} }
@TargetApi(23) @TargetApi(23)
...@@ -184,6 +191,8 @@ sealed class DnsResolverCompat { ...@@ -184,6 +191,8 @@ sealed class DnsResolverCompat {
override fun bindSocket(network: Network, socket: FileDescriptor) = network.bindSocket(socket) override fun bindSocket(network: Network, socket: FileDescriptor) = network.bindSocket(socket)
private val activeNetwork get() = Core.connectivity.activeNetwork ?: throw IOException("no network")
override suspend fun resolve(network: Network, host: String): Array<InetAddress> { override suspend fun resolve(network: Network, host: String): Array<InetAddress> {
return suspendCancellableCoroutine { cont -> return suspendCancellableCoroutine { cont ->
val signal = CancellationSignal() val signal = CancellationSignal()
...@@ -197,10 +206,7 @@ sealed class DnsResolverCompat { ...@@ -197,10 +206,7 @@ sealed class DnsResolverCompat {
}) })
} }
} }
override suspend fun resolveOnActiveNetwork(host: String) = resolve(activeNetwork, host)
override suspend fun resolveOnActiveNetwork(host: String): Array<InetAddress> {
return resolve(Core.connectivity.activeNetwork ?: return emptyArray(), host)
}
override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray { override suspend fun resolveRaw(network: Network, query: ByteArray): ByteArray {
return suspendCancellableCoroutine { cont -> return suspendCancellableCoroutine { cont ->
...@@ -213,5 +219,6 @@ sealed class DnsResolverCompat { ...@@ -213,5 +219,6 @@ sealed class DnsResolverCompat {
}) })
} }
} }
override suspend fun resolveRawOnActiveNetwork(query: ByteArray) = resolveRaw(activeNetwork, query)
} }
} }
package com.github.shadowsocks.bg
import android.net.LocalSocket
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core
import com.github.shadowsocks.net.ConcurrentLocalSocketListener
import com.github.shadowsocks.net.LocalDnsServer
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.launch
import org.xbill.DNS.Message
import org.xbill.DNS.Rcode
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.File
import java.io.IOException
class LocalDnsWorker(private val resolver: suspend (ByteArray) -> ByteArray) : ConcurrentLocalSocketListener(
"LocalDnsThread", File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
val query = ByteArray(input.readUnsignedShort())
input.read(query)
try {
resolver(query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
try {
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
} catch (_: IOException) {
byteArrayOf() // return empty if cannot parse packet
}
}?.let { response ->
val output = DataOutputStream(socket.outputStream)
output.writeShort(response.size)
output.write(response)
}
}
}
}
}
...@@ -30,8 +30,6 @@ import android.os.Build ...@@ -30,8 +30,6 @@ import android.os.Build
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.system.ErrnoException import android.system.ErrnoException
import android.system.OsConstants import android.system.OsConstants
import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core import com.github.shadowsocks.Core
import com.github.shadowsocks.VpnRequestActivity import com.github.shadowsocks.VpnRequestActivity
import com.github.shadowsocks.acl.Acl import com.github.shadowsocks.acl.Acl
...@@ -43,8 +41,6 @@ import com.github.shadowsocks.utils.closeQuietly ...@@ -43,8 +41,6 @@ import com.github.shadowsocks.utils.closeQuietly
import com.github.shadowsocks.utils.int import com.github.shadowsocks.utils.int
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.* import kotlinx.coroutines.*
import org.xbill.DNS.Message
import org.xbill.DNS.Rcode
import java.io.* import java.io.*
import java.net.URL import java.net.URL
import android.net.VpnService as BaseVpnService import android.net.VpnService as BaseVpnService
...@@ -87,37 +83,6 @@ class VpnService : BaseVpnService(), BaseService.Interface { ...@@ -87,37 +83,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
} }
} }
private inner class LocalDnsWorker : ConcurrentLocalSocketListener("LocalDnsThread",
File(Core.deviceStorage.noBackupFilesDir, "local_dns_path")), CoroutineScope {
override fun acceptInternal(socket: LocalSocket) = error("big no no")
override fun accept(socket: LocalSocket) {
launch {
socket.use {
val input = DataInputStream(socket.inputStream)
val query = ByteArray(input.readUnsignedShort())
input.read(query)
try {
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
} catch (e: Exception) {
when (e) {
is TimeoutCancellationException -> Crashlytics.log(Log.WARN, name, "Resolving timed out")
is CancellationException -> { } // ignore
is IOException -> Crashlytics.log(Log.WARN, name, e.message)
else -> printLog(e)
}
LocalDnsServer.prepareDnsResponse(Message(query)).apply {
header.rcode = Rcode.SERVFAIL
}.toWire()
}?.let { response ->
val output = DataOutputStream(socket.outputStream)
output.writeShort(response.size)
output.write(response)
}
}
}
}
}
inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException { inner class NullConnectionException : NullPointerException(), BaseService.ExpectedException {
override fun getLocalizedMessage() = getString(R.string.reboot_required) override fun getLocalizedMessage() = getString(R.string.reboot_required)
} }
...@@ -129,7 +94,6 @@ class VpnService : BaseVpnService(), BaseService.Interface { ...@@ -129,7 +94,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
private var conn: ParcelFileDescriptor? = null private var conn: ParcelFileDescriptor? = null
private var worker: ProtectWorker? = null private var worker: ProtectWorker? = null
private var localDns: LocalDnsWorker? = null
private var active = false private var active = false
private var metered = false private var metered = false
private var underlyingNetwork: Network? = null private var underlyingNetwork: Network? = null
...@@ -154,8 +118,6 @@ class VpnService : BaseVpnService(), BaseService.Interface { ...@@ -154,8 +118,6 @@ class VpnService : BaseVpnService(), BaseService.Interface {
scope.launch { DefaultNetworkListener.stop(this) } scope.launch { DefaultNetworkListener.stop(this) }
worker?.shutdown(scope) worker?.shutdown(scope)
worker = null worker = null
localDns?.shutdown(scope)
localDns = null
conn?.close() conn?.close()
conn = null conn = null
} }
...@@ -173,11 +135,14 @@ class VpnService : BaseVpnService(), BaseService.Interface { ...@@ -173,11 +135,14 @@ class VpnService : BaseVpnService(), BaseService.Interface {
override suspend fun preInit() = DefaultNetworkListener.start(this) { underlyingNetwork = it } override suspend fun preInit() = DefaultNetworkListener.start(this) { underlyingNetwork = it }
override suspend fun getActiveNetwork() = DefaultNetworkListener.get() override suspend fun getActiveNetwork() = DefaultNetworkListener.get()
override suspend fun resolver(host: String) = DnsResolverCompat.resolve(DefaultNetworkListener.get(), host) override suspend fun resolver(host: String) = DnsResolverCompat.resolve(DefaultNetworkListener.get(), host)
override suspend fun rawResolver(query: ByteArray) =
// no need to listen for network here as this is only used for forwarding local DNS queries.
// retries should be attempted by client.
DnsResolverCompat.resolveRaw(underlyingNetwork ?: throw IOException("no network"), query)
override suspend fun openConnection(url: URL) = DefaultNetworkListener.get().openConnection(url) override suspend fun openConnection(url: URL) = DefaultNetworkListener.get().openConnection(url)
override suspend fun startProcesses(hosts: HostsFile) { override suspend fun startProcesses(hosts: HostsFile) {
worker = ProtectWorker().apply { start() } worker = ProtectWorker().apply { start() }
localDns = LocalDnsWorker().apply { start() }
super.startProcesses(hosts) super.startProcesses(hosts)
sendFd(startVpn()) sendFd(startVpn())
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment