Commit f8142a6f authored by Mygod's avatar Mygod

Refine code style

parent 38c78a90
......@@ -29,8 +29,8 @@ import java.nio.channels.*
class ChannelMonitor : Thread("ChannelMonitor") {
private data class Registration(val channel: SelectableChannel,
val ops: Int,
val listener: (SelectionKey) -> Unit) {
val ops: Int,
val listener: (SelectionKey) -> Unit) {
val result = CompletableDeferred<SelectionKey>()
}
......
......@@ -70,14 +70,16 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
suspend fun start(listen: SocketAddress) = DatagramChannel.open().apply {
configureBlocking(false)
socket().bind(listen)
monitor.register(this, SelectionKey.OP_READ) {
val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE)
val source = receive(buffer)!!
buffer.flip()
this@LocalDnsServer.launch {
val reply = resolve(buffer)
while (send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE)
}
monitor.register(this, SelectionKey.OP_READ) { handlePacket(this) }
}
private fun handlePacket(channel: DatagramChannel) {
val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE)
val source = channel.receive(buffer)!!
buffer.flip()
launch {
val reply = resolve(buffer)
while (channel.send(reply, source) <= 0) monitor.wait(channel, SelectionKey.OP_WRITE)
}
}
......@@ -89,7 +91,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
return forward(packet)
}
return coroutineScope {
val remote = async { forward(packet) }
val remote = async { withTimeout(TIMEOUT) { forward(packet) } }
try {
if (forwardOnly || request.header.opcode != Opcode.QUERY) return@coroutineScope remote.await()
val question = request.question
......@@ -135,29 +137,27 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private suspend fun forward(packet: ByteBuffer): ByteBuffer {
packet.position(0) // the packet might have been parsed, reset to beginning
return withTimeout(TIMEOUT) {
if (tcp) SocketChannel.open().use { channel ->
channel.configureBlocking(false)
channel.connect(proxy)
val wrapped = remoteDns.tcpWrap(packet)
while (!channel.finishConnect()) monitor.wait(channel, SelectionKey.OP_CONNECT)
while (channel.write(wrapped) >= 0 && wrapped.hasRemaining()) {
monitor.wait(channel, SelectionKey.OP_WRITE)
}
val result = remoteDns.tcpReceiveBuffer(UDP_PACKET_SIZE)
remoteDns.tcpUnwrap(result, channel::read) { monitor.wait(channel, SelectionKey.OP_READ) }
result
} else DatagramChannel.open().use { channel ->
channel.configureBlocking(false)
return if (tcp) SocketChannel.open().use { channel ->
channel.configureBlocking(false)
channel.connect(proxy)
val wrapped = remoteDns.tcpWrap(packet)
while (!channel.finishConnect()) monitor.wait(channel, SelectionKey.OP_CONNECT)
while (channel.write(wrapped) >= 0 && wrapped.hasRemaining()) {
monitor.wait(channel, SelectionKey.OP_WRITE)
check(channel.send(remoteDns.udpWrap(packet), proxy) > 0)
monitor.wait(channel, SelectionKey.OP_READ)
val result = remoteDns.udpReceiveBuffer(UDP_PACKET_SIZE)
check(channel.receive(result) == proxy)
result.flip()
remoteDns.udpUnwrap(result)
result
}
val result = remoteDns.tcpReceiveBuffer(UDP_PACKET_SIZE)
remoteDns.tcpUnwrap(result, channel::read) { monitor.wait(channel, SelectionKey.OP_READ) }
result
} else DatagramChannel.open().use { channel ->
channel.configureBlocking(false)
monitor.wait(channel, SelectionKey.OP_WRITE)
check(channel.send(remoteDns.udpWrap(packet), proxy) > 0)
monitor.wait(channel, SelectionKey.OP_READ)
val result = remoteDns.udpReceiveBuffer(UDP_PACKET_SIZE)
check(channel.receive(result) == proxy)
result.flip()
remoteDns.udpUnwrap(result)
result
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment