Commit c102406c authored by Mygod's avatar Mygod

Clean up API

parent 724292a5
...@@ -65,7 +65,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -65,7 +65,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private val monitor = ChannelMonitor() private val monitor = ChannelMonitor()
private val job = SupervisorJob() private val job = Job()
override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) } override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
fun start(listen: SocketAddress) = DatagramChannel.open().apply { fun start(listen: SocketAddress) = DatagramChannel.open().apply {
...@@ -77,7 +77,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -77,7 +77,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
buffer.flip() buffer.flip()
launch { launch {
val reply = resolve(buffer) val reply = resolve(buffer)
while (job.isActive && send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE) while (isActive && send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE)
} }
} }
} }
...@@ -131,21 +131,25 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -131,21 +131,25 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private suspend fun forward(packet: ByteBuffer): ByteBuffer { private suspend fun forward(packet: ByteBuffer): ByteBuffer {
packet.position(0) // the packet might have been parsed, reset to beginning packet.position(0) // the packet might have been parsed, reset to beginning
return withTimeout(TIMEOUT) { return withTimeout(TIMEOUT) {
if (tcp) SocketChannel.open().use { if (tcp) SocketChannel.open().use { channel ->
it.configureBlocking(false) channel.configureBlocking(false)
it.connect(proxy) channel.connect(proxy)
val wrapped = remoteDns.tcpWrap(packet) val wrapped = remoteDns.tcpWrap(packet)
while (job.isActive && !it.finishConnect()) monitor.wait(it, SelectionKey.OP_CONNECT) while (isActive && !channel.finishConnect()) monitor.wait(channel, SelectionKey.OP_CONNECT)
while (job.isActive && it.write(wrapped) >= 0 && wrapped.hasRemaining()) monitor.wait(it, SelectionKey.OP_WRITE) while (isActive && channel.write(wrapped) >= 0 && wrapped.hasRemaining()) {
remoteDns.tcpUnwrap(UDP_PACKET_SIZE, it::read) { monitor.wait(it, SelectionKey.OP_READ) } monitor.wait(channel, SelectionKey.OP_WRITE)
} else DatagramChannel.open().use { }
it.configureBlocking(false) val result = remoteDns.tcpReceiveBuffer(UDP_PACKET_SIZE)
monitor.wait(it, SelectionKey.OP_WRITE) remoteDns.tcpUnwrap(result, channel::read) { monitor.wait(channel, SelectionKey.OP_READ) }
check(it.send(remoteDns.udpWrap(packet), proxy) > 0) result
monitor.wait(it, SelectionKey.OP_READ) } else DatagramChannel.open().use { channel ->
channel.configureBlocking(false)
monitor.wait(channel, SelectionKey.OP_WRITE)
check(channel.send(remoteDns.udpWrap(packet), proxy) > 0)
monitor.wait(channel, SelectionKey.OP_READ)
val result = remoteDns.udpReceiveBuffer(UDP_PACKET_SIZE) val result = remoteDns.udpReceiveBuffer(UDP_PACKET_SIZE)
check(it.receive(result) == proxy) check(channel.receive(result) == proxy)
result.limit(result.position()) result.flip()
remoteDns.udpUnwrap(result) remoteDns.udpUnwrap(result)
result result
}.also { Log.d("forward", "completed $it") } }.also { Log.d("forward", "completed $it") }
......
...@@ -65,42 +65,37 @@ class Socks5Endpoint(host: String, port: Int) { ...@@ -65,42 +65,37 @@ class Socks5Endpoint(host: String, port: Int) {
} }
} }
fun tcpReceiveBuffer(size: Int) = ByteBuffer.allocate(headerReserved + 4 + size) fun tcpReceiveBuffer(size: Int) = ByteBuffer.allocate(headerReserved + 4 + size)
suspend fun tcpUnwrap(size: Int, reader: (ByteBuffer) -> Int, wait: suspend () -> Unit): ByteBuffer { suspend fun tcpUnwrap(buffer: ByteBuffer, reader: (ByteBuffer) -> Int, wait: suspend () -> Unit) {
suspend fun ByteBuffer.readBytes(till: Int) { suspend fun readBytes(till: Int) {
if (position() >= till) return if (buffer.position() >= till) return
while (reader(this) >= 0 && position() < till) wait() while (reader(buffer) >= 0 && buffer.position() < till) wait()
if (position() < till) throw IOException("EOF") if (buffer.position() < till) throw IOException("EOF")
} }
suspend fun ByteBuffer.read(index: Int): Byte { suspend fun read(index: Int): Byte {
readBytes(index + 1) readBytes(index + 1)
return get(index) return buffer[index]
} }
val buffer = tcpReceiveBuffer(size) check(read(0) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" }
check(buffer.read(0) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" } if (read(1) != 0.toByte()) throw IOException("Unsupported authentication ${buffer[1]}")
if (buffer.read(1) != 0.toByte()) throw IOException("Unsupported authentication ${buffer[1]}") check(read(2) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" }
check(buffer.read(2) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" } if (read(3) != 0.toByte()) throw IOException("SOCKS5 server returned error ${buffer[3]}")
if (buffer.read(3) != 0.toByte()) throw IOException("SOCKS5 server returned error ${buffer[3]}") val dataOffset = when (read(5)) {
val dataOffset = when (buffer.read(5)) {
Socks5Message.SOCKS_ATYP_IPV4.toByte() -> 4 Socks5Message.SOCKS_ATYP_IPV4.toByte() -> 4
Socks5Message.SOCKS_ATYP_DOMAINNAME.toByte() -> { Socks5Message.SOCKS_ATYP_DOMAINNAME.toByte() -> 1 + read(6)
buffer.readBytes(4)
1 + buffer[3]
}
Socks5Message.SOCKS_ATYP_IPV6.toByte() -> 16 Socks5Message.SOCKS_ATYP_IPV6.toByte() -> 16
else -> throw IllegalStateException("Unsupported address type ${buffer[5]}") else -> throw IllegalStateException("Unsupported address type ${buffer[5]}")
} + 8 } + 8
buffer.readBytes(dataOffset + 2) readBytes(dataOffset + 2)
buffer.limit(buffer.position()) // store old position to update mark buffer.limit(buffer.position()) // store old position to update mark
buffer.position(dataOffset) buffer.position(dataOffset)
val dataLength = buffer.short.toUShort().toInt() val dataLength = buffer.short.toUShort().toInt()
check(dataLength <= size) { "Buffer too small to contain the message" }
buffer.mark()
val end = buffer.position() + dataLength val end = buffer.position() + dataLength
check(end <= buffer.capacity()) { "Buffer too small to contain the message" }
buffer.mark()
buffer.position(buffer.limit()) // restore old position buffer.position(buffer.limit()) // restore old position
buffer.limit(end) buffer.limit(end)
buffer.readBytes(buffer.limit()) readBytes(buffer.limit())
buffer.reset() buffer.reset()
return buffer
} }
fun udpWrap(packet: ByteBuffer) = ByteBuffer.allocate(3 + dest.size + packet.remaining()).apply { fun udpWrap(packet: ByteBuffer) = ByteBuffer.allocate(3 + dest.size + packet.remaining()).apply {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment