Commit 00dfb925 authored by Mygod's avatar Mygod

Fix async resolver

parent 0faff461
...@@ -88,46 +88,48 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -88,46 +88,48 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
printLog(e) printLog(e)
return forward(packet) return forward(packet)
} }
val remote = coroutineScope { async { forward(packet) } } return coroutineScope {
try { val remote = async { forward(packet) }
if (forwardOnly || request.header.opcode != Opcode.QUERY) return remote.await() try {
val question = request.question if (forwardOnly || request.header.opcode != Opcode.QUERY) return@coroutineScope remote.await()
if (question?.type != Type.A) return remote.await() val question = request.question
val host = question.name.toString(true) if (question?.type != Type.A) return@coroutineScope remote.await()
if (remoteDomainMatcher?.containsMatchIn(host) == true) return remote.await() val host = question.name.toString(true)
val localResults = try { if (remoteDomainMatcher?.containsMatchIn(host) == true) return@coroutineScope remote.await()
withTimeout(TIMEOUT) { GlobalScope.async(Dispatchers.IO) { localResolver(host) }.await() } val localResults = try {
} catch (_: TimeoutCancellationException) { withTimeout(TIMEOUT) { GlobalScope.async(Dispatchers.IO) { localResolver(host) }.await() }
Log.w("LocalDnsServer", "Local resolving timed out, falling back to remote resolving") } catch (_: TimeoutCancellationException) {
return remote.await() Log.w("LocalDnsServer", "Local resolving timed out, falling back to remote resolving")
} catch (_: UnknownHostException) { return@coroutineScope remote.await()
return remote.await() } catch (_: UnknownHostException) {
} return@coroutineScope remote.await()
if (localResults.isEmpty()) return remote.await() }
if (localIpMatcher.isEmpty() || localIpMatcher.any { subnet -> localResults.any(subnet::matches) }) { if (localResults.isEmpty()) return@coroutineScope remote.await()
if (localIpMatcher.isEmpty() || localIpMatcher.any { subnet -> localResults.any(subnet::matches) }) {
remote.cancel()
val response = Message(request.header.id)
response.header.setFlag(Flags.QR.toInt()) // this is a response
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.header.setFlag(Flags.RA.toInt()) // recursion available
response.addRecord(request.question, Section.QUESTION)
for (address in localResults) response.addRecord(when (address) {
is Inet4Address -> ARecord(request.question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(request.question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
return@coroutineScope ByteBuffer.wrap(response.toWire())
}
return@coroutineScope remote.await()
} catch (e: IOException) {
remote.cancel() remote.cancel()
printLog(e)
val response = Message(request.header.id) val response = Message(request.header.id)
response.header.setFlag(Flags.QR.toInt()) // this is a response response.header.rcode = Rcode.SERVFAIL
response.header.setFlag(Flags.QR.toInt())
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt()) if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.header.setFlag(Flags.RA.toInt()) // recursion available
response.addRecord(request.question, Section.QUESTION) response.addRecord(request.question, Section.QUESTION)
for (address in localResults) response.addRecord(when (address) { return@coroutineScope ByteBuffer.wrap(response.toWire())
is Inet4Address -> ARecord(request.question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(request.question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
return ByteBuffer.wrap(response.toWire())
} }
return remote.await()
} catch (e: IOException) {
remote.cancel()
printLog(e)
val response = Message(request.header.id)
response.header.rcode = Rcode.SERVFAIL
response.header.setFlag(Flags.QR.toInt())
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.addRecord(request.question, Section.QUESTION)
return ByteBuffer.wrap(response.toWire())
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment