Commit f8142a6f authored by Mygod's avatar Mygod

Refine code style

parent 38c78a90
...@@ -70,14 +70,16 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -70,14 +70,16 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
suspend fun start(listen: SocketAddress) = DatagramChannel.open().apply { suspend fun start(listen: SocketAddress) = DatagramChannel.open().apply {
configureBlocking(false) configureBlocking(false)
socket().bind(listen) socket().bind(listen)
monitor.register(this, SelectionKey.OP_READ) { monitor.register(this, SelectionKey.OP_READ) { handlePacket(this) }
}
private fun handlePacket(channel: DatagramChannel) {
val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE) val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE)
val source = receive(buffer)!! val source = channel.receive(buffer)!!
buffer.flip() buffer.flip()
this@LocalDnsServer.launch { launch {
val reply = resolve(buffer) val reply = resolve(buffer)
while (send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE) while (channel.send(reply, source) <= 0) monitor.wait(channel, SelectionKey.OP_WRITE)
}
} }
} }
...@@ -89,7 +91,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -89,7 +91,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
return forward(packet) return forward(packet)
} }
return coroutineScope { return coroutineScope {
val remote = async { forward(packet) } val remote = async { withTimeout(TIMEOUT) { forward(packet) } }
try { try {
if (forwardOnly || request.header.opcode != Opcode.QUERY) return@coroutineScope remote.await() if (forwardOnly || request.header.opcode != Opcode.QUERY) return@coroutineScope remote.await()
val question = request.question val question = request.question
...@@ -135,8 +137,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -135,8 +137,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private suspend fun forward(packet: ByteBuffer): ByteBuffer { private suspend fun forward(packet: ByteBuffer): ByteBuffer {
packet.position(0) // the packet might have been parsed, reset to beginning packet.position(0) // the packet might have been parsed, reset to beginning
return withTimeout(TIMEOUT) { return if (tcp) SocketChannel.open().use { channel ->
if (tcp) SocketChannel.open().use { channel ->
channel.configureBlocking(false) channel.configureBlocking(false)
channel.connect(proxy) channel.connect(proxy)
val wrapped = remoteDns.tcpWrap(packet) val wrapped = remoteDns.tcpWrap(packet)
...@@ -159,7 +160,6 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -159,7 +160,6 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
result result
} }
} }
}
fun shutdown(scope: CoroutineScope) { fun shutdown(scope: CoroutineScope) {
job.cancel() job.cancel()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment