Commit ca25482f authored by Mygod's avatar Mygod

Implement DNS server with NIO

parent 07ec96bf
......@@ -24,8 +24,10 @@ import com.github.shadowsocks.Core.app
import com.github.shadowsocks.acl.Acl
import com.github.shadowsocks.core.R
import com.github.shadowsocks.net.LocalDnsServer
import com.github.shadowsocks.net.Socks5Endpoint
import com.github.shadowsocks.net.Subnet
import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.preference.DataStore
import java.net.InetSocketAddress
import java.util.*
object LocalDnsService {
......@@ -44,7 +46,8 @@ object LocalDnsService {
val data = data
val profile = data.proxy!!.profile
if (!profile.udpdns) servers[this] = LocalDnsServer(this::resolver,
profile.remoteDns.split(",").first().parseNumericAddress()!!).apply {
Socks5Endpoint(profile.remoteDns.split(",").first(), 53),
DataStore.proxyAddress).apply {
when (profile.route) {
Acl.BYPASS_CHN, Acl.BYPASS_LAN_CHN, Acl.GFWLIST, Acl.CUSTOM_RULES -> {
remoteDomainMatcher = googleApisTester
......@@ -53,7 +56,7 @@ object LocalDnsService {
Acl.CHINALIST -> { }
else -> forwardOnly = true
}
start()
start(InetSocketAddress(DataStore.listenAddress, DataStore.portLocalDns))
}
}
......
......@@ -20,40 +20,79 @@
package com.github.shadowsocks.net
import net.sourceforge.jsocks.Socks5Message
import java.io.ByteArrayInputStream
import java.net.*
import java.nio.ByteBuffer
import java.nio.channels.Pipe
import java.nio.channels.SelectableChannel
import java.nio.channels.SelectionKey
import java.nio.channels.Selector
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
class Socks5DatagramSocket(proxy: Proxy) : DatagramSocket() {
private val proxy = proxy.address() as InetSocketAddress
override fun send(dp: DatagramPacket) {
val data = ByteBuffer.allocate(6 + dp.address.address.size + dp.length).apply {
// header
putShort(0) // reserved
put(0) // fragment number
put(when (dp.address) {
is Inet4Address -> Socks5Message.SOCKS_ATYP_IPV4
is Inet6Address -> Socks5Message.SOCKS_ATYP_IPV6
else -> throw IllegalStateException("Unsupported address type")
}.toByte())
put(dp.address.address)
putShort(dp.port.toShort())
// data
put(dp.data, dp.offset, dp.length)
}.array()
super.send(DatagramPacket(data, data.size, proxy.address, proxy.port))
class ChannelMonitor : Thread("ChannelMonitor"), AutoCloseable {
private val selector = Selector.open()
private val registrationPipe = Pipe.open()
private val pendingRegistrations = ConcurrentLinkedQueue<Triple<SelectableChannel, Int, (SelectionKey) -> Unit>>()
@Volatile
private var running = true
private fun registerInternal(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit) =
channel.register(selector, ops, block)
init {
registrationPipe.source().apply {
configureBlocking(false)
registerInternal(this, SelectionKey.OP_READ) {
val junk = ByteBuffer.allocateDirect(1)
while (read(junk) > 0) {
val (channel, ops, block) = pendingRegistrations.remove()
registerInternal(channel, ops, block)
junk.clear()
}
}
}
start()
}
fun register(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit) {
pendingRegistrations.add(Triple(channel, ops, block))
val junk = ByteBuffer.allocateDirect(1)
while (registrationPipe.sink().write(junk) == 0);
}
suspend fun wait(channel: SelectableChannel, ops: Int) = suspendCoroutine<Unit> { continuation ->
register(channel, ops) {
it.interestOps(0)
continuation.resume(Unit)
}
}
suspend fun waitWhile(channel: SelectableChannel, ops: Int, condition: () -> Boolean) {
if (condition()) suspendCoroutine<Unit> { continuation ->
register(channel, ops) {
if (condition()) return@register
it.interestOps(0)
continuation.resume(Unit)
}
}
}
override fun run() {
while (running) {
if (selector.select() <= 0) continue
val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
iterator.remove()
(key.attachment() as (SelectionKey) -> Unit)(key)
}
}
}
override fun receive(dp: DatagramPacket) {
super.receive(dp)
check(proxy.address == dp.address && proxy.port == dp.port) { "Unexpected packet" }
val stream = ByteArrayInputStream(dp.data, dp.offset, dp.length)
val msg = Socks5Message(stream)
dp.port = msg.port
dp.address = msg.inetAddress
val remaining = stream.available()
dp.setData(dp.data, dp.offset + dp.length - remaining, remaining)
override fun close() {
running = false
selector.wakeup()
join()
selector.keys().forEach { it.channel().close() }
}
}
......@@ -26,7 +26,6 @@ import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import android.os.Build
import android.widget.Toast
import androidx.core.content.getSystemService
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core.app
......
......@@ -33,6 +33,7 @@ import com.github.shadowsocks.utils.responseLength
import kotlinx.coroutines.*
import java.io.IOException
import java.net.HttpURLConnection
import java.net.Proxy
import java.net.URL
/**
......@@ -83,7 +84,7 @@ class HttpsTest : ViewModel() {
else -> "www.google.com"
}, "/generate_204")
val conn = (if (DataStore.serviceMode != Key.modeVpn) {
url.openConnection(DataStore.proxy)
url.openConnection(Proxy(Proxy.Type.SOCKS, DataStore.proxyAddress))
} else url.openConnection()) as HttpURLConnection
conn.setRequestProperty("Connection", "close")
conn.instanceFollowRedirects = false
......
......@@ -20,18 +20,16 @@
package com.github.shadowsocks.net
import android.os.ParcelFileDescriptor
import com.github.shadowsocks.preference.DataStore
import com.github.shadowsocks.utils.parseNumericAddress
import android.util.Log
import com.github.shadowsocks.utils.printLog
import com.github.shadowsocks.utils.shutdown
import kotlinx.coroutines.*
import org.xbill.DNS.*
import java.io.*
import java.io.IOException
import java.net.*
import java.nio.ByteBuffer
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.nio.channels.DatagramChannel
import java.nio.channels.SelectionKey
import java.nio.channels.SocketChannel
/**
* A simple DNS conditional forwarder.
......@@ -43,7 +41,7 @@ import java.util.concurrent.ConcurrentHashMap
* https://github.com/shadowsocks/overture/tree/874f22613c334a3b78e40155a55479b7b69fee04
*/
class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAddress>,
private val remoteDns: InetAddress) : SocketListener("LocalDnsServer"), CoroutineScope {
private val remoteDns: Socks5Endpoint, private val proxy: SocketAddress) : CoroutineScope {
/**
* Forward all requests to remote and ignore localResolver.
*/
......@@ -64,37 +62,28 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512
}
private val socket = DatagramSocket(DataStore.portLocalDns, DataStore.listenAddress.parseNumericAddress())
private val DatagramSocket.fileDescriptor get() = ParcelFileDescriptor.fromDatagramSocket(this).fileDescriptor
override val fileDescriptor get() = socket.fileDescriptor
private val proxy = DataStore.proxy
private val monitor = ChannelMonitor()
private val activeFds = Collections.newSetFromMap(ConcurrentHashMap<FileDescriptor, Boolean>())
private val job = SupervisorJob()
override val coroutineContext get() = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
override fun run() {
while (running) {
val packet = DatagramPacket(ByteArray(UDP_PACKET_SIZE), 0, UDP_PACKET_SIZE)
try {
socket.receive(packet)
launch {
resolve(packet) // this method should also put the reply in the packet
socket.send(packet)
}
} catch (e: RuntimeException) {
e.printStackTrace()
fun start(listen: SocketAddress) = DatagramChannel.open().apply {
configureBlocking(false)
socket().bind(listen)
monitor.register(this, SelectionKey.OP_READ) {
val buffer = ByteBuffer.allocate(UDP_PACKET_SIZE)
val source = receive(buffer)!!
buffer.flip()
launch {
val reply = resolve(buffer)
while (send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE)
}
}
socket.close()
}
private suspend fun <T> io(block: suspend CoroutineScope.() -> T) =
withTimeout(TIMEOUT) { withContext(Dispatchers.IO, block) }
private suspend fun resolve(packet: DatagramPacket) {
private suspend fun resolve(packet: ByteBuffer): ByteBuffer {
val request = try {
Message(ByteBuffer.wrap(packet.data, packet.offset, packet.length))
Message(packet)
} catch (e: IOException) { // we cannot parse the message, do not attempt to handle it at all
printLog(e)
return forward(packet)
......@@ -106,7 +95,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
val host = question.name.toString(true)
if (remoteDomainMatcher?.containsMatchIn(host) == true) return forward(packet)
val localResults = try {
io { localResolver(host) }
withTimeout(TIMEOUT) { withContext(Dispatchers.IO) { localResolver(host) } }
} catch (_: TimeoutCancellationException) {
return forward(packet)
} catch (_: UnknownHostException) {
......@@ -124,8 +113,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
is Inet6Address -> AAAARecord(request.question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
val wire = response.toWire()
return packet.setData(wire, 0, wire.size)
return ByteBuffer.wrap(response.toWire())
}
return forward(packet)
} catch (e: IOException) {
......@@ -135,54 +123,38 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
response.header.setFlag(Flags.QR.toInt())
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.addRecord(request.question, Section.QUESTION)
val wire = response.toWire()
return packet.setData(wire, 0, wire.size)
return ByteBuffer.wrap(response.toWire())
}
}
private suspend fun forward(packet: DatagramPacket) = if (tcp) Socket(proxy).useFd {
it.connect(InetSocketAddress(remoteDns, 53))
DataOutputStream(it.getOutputStream()).apply {
writeShort(packet.length)
write(packet.data, packet.offset, packet.length)
flush()
}
DataInputStream(it.getInputStream()).apply {
packet.length = readUnsignedShort()
readFully(packet.data, packet.offset, packet.length)
}
} else Socks5DatagramSocket(proxy).useFd {
val address = packet.address // we are reusing the packet, save it first
val port = packet.port
packet.address = remoteDns
packet.port = 53
it.send(packet)
packet.length = UDP_PACKET_SIZE
it.receive(packet)
packet.address = address
packet.port = port
}
private suspend fun <T : Closeable> T.useFd(block: (T) -> Unit) {
val fd = when (this) {
is Socket -> ParcelFileDescriptor.fromSocket(this).fileDescriptor
is DatagramSocket -> fileDescriptor
else -> throw IllegalStateException("Unsupported type $javaClass for obtaining FileDescriptor")
}
try {
activeFds += fd
io { use(block) }
} finally {
fd.shutdown()
activeFds -= fd
private suspend fun forward(packet: ByteBuffer): ByteBuffer {
packet.position(0) // the packet might have been parsed, reset to beginning
return withTimeout(TIMEOUT) {
if (tcp) SocketChannel.open().use {
it.configureBlocking(false)
it.connect(proxy)
val wrapped = remoteDns.tcpWrap(packet)
while (!it.finishConnect()) monitor.wait(it, SelectionKey.OP_CONNECT)
// monitor.waitWhile(it, SelectionKey.OP_WRITE) { it.write(wrapped) >= 0 && wrapped.hasRemaining() }
while (it.write(wrapped) >= 0 && wrapped.hasRemaining()) monitor.wait(it, SelectionKey.OP_WRITE)
remoteDns.tcpUnwrap(UDP_PACKET_SIZE, it::read) { monitor.wait(it, SelectionKey.OP_READ) }
} else DatagramChannel.open().use {
it.configureBlocking(false)
monitor.wait(it, SelectionKey.OP_WRITE)
check(it.send(remoteDns.udpWrap(packet), proxy) > 0)
monitor.wait(it, SelectionKey.OP_READ)
val result = remoteDns.udpReceiveBuffer(UDP_PACKET_SIZE)
check(it.receive(result) == proxy)
result.limit(result.position())
remoteDns.udpUnwrap(result)
result
}.also { Log.d("forward", "completed $it") }
}
}
suspend fun shutdown() {
running = false
job.cancel()
close()
activeFds.forEach { it.shutdown() }
monitor.close()
job.join()
}
}
/*******************************************************************************
* *
* Copyright (C) 2019 by Max Lv <max.c.lv@gmail.com> *
* Copyright (C) 2019 by Mygod Studio <contact-shadowsocks-android@mygod.be> *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*******************************************************************************/
package com.github.shadowsocks.net
import com.github.shadowsocks.utils.parseNumericAddress
import net.sourceforge.jsocks.Socks4Message
import net.sourceforge.jsocks.Socks5Message
import java.io.IOException
import java.net.Inet4Address
import java.net.Inet6Address
import java.nio.ByteBuffer
import kotlin.math.max
class Socks5Endpoint(host: String, port: Int) {
private val dest = host.parseNumericAddress().let { numeric ->
val bytes = numeric?.address ?: host.toByteArray().apply { check(size < 256) { "Hostname too long" } }
val type = when (numeric) {
null -> Socks5Message.SOCKS_ATYP_DOMAINNAME
is Inet4Address -> Socks5Message.SOCKS_ATYP_IPV4
is Inet6Address -> Socks5Message.SOCKS_ATYP_IPV6
else -> throw IllegalStateException("Unsupported address type")
}
ByteBuffer.allocate(bytes.size + (if (numeric == null) 1 else 0) + 3).apply {
put(type.toByte())
if (numeric == null) put(bytes.size.toByte())
put(bytes)
putShort(port.toShort())
}
}.array()
private val headerReserved = max(3 + 3 + 16, 3 + dest.size)
fun tcpWrap(message: ByteBuffer): ByteBuffer {
check(message.remaining() < 65536) { "TCP message too large" }
return ByteBuffer.allocate(8 + dest.size + message.remaining()).apply {
put(Socks5Message.SOCKS_VERSION.toByte())
put(1) // nmethods
put(0) // no authentication required
// header
put(Socks5Message.SOCKS_VERSION.toByte())
put(Socks4Message.REQUEST_CONNECT.toByte())
put(0) // reserved
put(dest)
// data
putShort(message.remaining().toShort())
put(message)
flip()
}
}
fun tcpReceiveBuffer(size: Int) = ByteBuffer.allocate(headerReserved + 4 + size)
suspend fun tcpUnwrap(size: Int, reader: (ByteBuffer) -> Int, wait: suspend () -> Unit): ByteBuffer {
suspend fun ByteBuffer.readBytes(till: Int) {
if (position() >= till) return
while (reader(this) >= 0 && position() < till) wait()
if (position() < till) throw IOException("EOF")
}
suspend fun ByteBuffer.read(index: Int): Byte {
readBytes(index + 1)
return get(index)
}
val buffer = tcpReceiveBuffer(size)
check(buffer.read(0) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" }
if (buffer.read(1) != 0.toByte()) throw IOException("Unsupported authentication ${buffer[1]}")
check(buffer.read(2) == Socks5Message.SOCKS_VERSION.toByte()) { "Unsupported SOCKS version" }
if (buffer.read(3) != 0.toByte()) throw IOException("SOCKS5 server returned error ${buffer[3]}")
val dataOffset = when (buffer.read(5)) {
Socks5Message.SOCKS_ATYP_IPV4.toByte() -> 4
Socks5Message.SOCKS_ATYP_DOMAINNAME.toByte() -> {
buffer.readBytes(4)
1 + buffer[3]
}
Socks5Message.SOCKS_ATYP_IPV6.toByte() -> 16
else -> throw IllegalStateException("Unsupported address type ${buffer[5]}")
} + 8
buffer.readBytes(dataOffset + 2)
buffer.limit(buffer.position()) // store old position to update mark
buffer.position(dataOffset)
val dataLength = buffer.short.toUShort().toInt()
check(dataLength <= size) { "Buffer too small to contain the message" }
buffer.mark()
val end = buffer.position() + dataLength
buffer.position(buffer.limit()) // restore old position
buffer.limit(end)
buffer.readBytes(buffer.limit())
buffer.reset()
return buffer
}
fun udpWrap(packet: ByteBuffer) = ByteBuffer.allocate(3 + dest.size + packet.remaining()).apply {
// header
putShort(0) // reserved
put(0) // fragment number
put(dest)
// data
put(packet)
flip()
}
fun udpReceiveBuffer(size: Int) = ByteBuffer.allocate(headerReserved + size)
fun udpUnwrap(packet: ByteBuffer) {
packet.position(3)
packet.position(6 + when (packet.get()) {
Socks5Message.SOCKS_ATYP_IPV4.toByte() -> 4
Socks5Message.SOCKS_ATYP_DOMAINNAME.toByte() -> 1 + packet.get()
Socks5Message.SOCKS_ATYP_IPV6.toByte() -> 16
else -> throw IllegalStateException("Unsupported address type")
})
packet.mark()
}
}
......@@ -90,7 +90,7 @@ object DataStore : OnPreferenceDataStoreChangeListener {
var portProxy: Int
get() = getLocalPort(Key.portProxy, 1080)
set(value) = publicStore.putString(Key.portProxy, value.toString())
val proxy get() = Proxy(Proxy.Type.SOCKS, InetSocketAddress("127.0.0.1", portProxy))
val proxyAddress get() = InetSocketAddress("127.0.0.1", portProxy)
var portLocalDns: Int
get() = getLocalPort(Key.portLocalDns, 5450)
set(value) = publicStore.putString(Key.portLocalDns, value.toString())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment