Commit add3aa4f authored by Mygod's avatar Mygod

Fix DNS resolving

parent b2e7ee54
/*******************************************************************************
* *
* Copyright (C) 2019 by Max Lv <max.c.lv@gmail.com> *
* Copyright (C) 2019 by Mygod Studio <contact-shadowsocks-android@mygod.be> *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*******************************************************************************/
package com.github.shadowsocks.net package com.github.shadowsocks.net
import android.os.ParcelFileDescriptor import android.os.ParcelFileDescriptor
import android.util.Log
import com.github.shadowsocks.preference.DataStore import com.github.shadowsocks.preference.DataStore
import com.github.shadowsocks.utils.parseNumericAddress import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import com.github.shadowsocks.utils.shutdown import com.github.shadowsocks.utils.shutdown
import kotlinx.coroutines.* import kotlinx.coroutines.*
import net.sourceforge.jsocks.Socks5DatagramSocket
import net.sourceforge.jsocks.Socks5Proxy
import org.xbill.DNS.* import org.xbill.DNS.*
import java.io.Closeable import java.io.*
import java.io.FileDescriptor
import java.io.IOException
import java.net.* import java.net.*
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.util.* import java.util.*
...@@ -47,14 +62,12 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -47,14 +62,12 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
* so we suppose Android apps should not care about TTL either. * so we suppose Android apps should not care about TTL either.
*/ */
private const val TTL = 120L private const val TTL = 120L
private const val UDP_PACKET_SIZE = 1500 private const val UDP_PACKET_SIZE = 512
} }
private val socket = DatagramSocket(DataStore.portLocalDns, DataStore.listenAddress.parseNumericAddress()) private val socket = DatagramSocket(DataStore.portLocalDns, DataStore.listenAddress.parseNumericAddress())
private val DatagramSocket.fileDescriptor get() = ParcelFileDescriptor.fromDatagramSocket(this).fileDescriptor private val DatagramSocket.fileDescriptor get() = ParcelFileDescriptor.fromDatagramSocket(this).fileDescriptor
override val fileDescriptor get() = socket.fileDescriptor override val fileDescriptor get() = socket.fileDescriptor
private val tcpProxy = DataStore.proxy private val proxy = DataStore.proxy
private val udpProxy = Socks5Proxy("127.0.0.1", DataStore.portProxy)
private val activeFds = Collections.newSetFromMap(ConcurrentHashMap<FileDescriptor, Boolean>()) private val activeFds = Collections.newSetFromMap(ConcurrentHashMap<FileDescriptor, Boolean>())
private val job = SupervisorJob() private val job = SupervisorJob()
...@@ -65,7 +78,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -65,7 +78,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
val packet = DatagramPacket(ByteArray(UDP_PACKET_SIZE), 0, UDP_PACKET_SIZE) val packet = DatagramPacket(ByteArray(UDP_PACKET_SIZE), 0, UDP_PACKET_SIZE)
try { try {
socket.receive(packet) socket.receive(packet)
launch(start = CoroutineStart.UNDISPATCHED) { launch {
resolve(packet) // this method should also put the reply in the packet resolve(packet) // this method should also put the reply in the packet
socket.send(packet) socket.send(packet)
} }
...@@ -73,69 +86,81 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -73,69 +86,81 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
e.printStackTrace() e.printStackTrace()
} }
} }
socket.close()
} }
private suspend fun <T> io(block: suspend CoroutineScope.() -> T) = private suspend fun <T> io(block: suspend CoroutineScope.() -> T) =
withTimeout(TIMEOUT) { withContext(Dispatchers.IO, block) } withTimeout(TIMEOUT) { withContext(Dispatchers.IO, block) }
private suspend fun resolve(packet: DatagramPacket) { private suspend fun resolve(packet: DatagramPacket) {
if (forwardOnly) return forward(packet)
val request = try { val request = try {
Message(ByteBuffer.wrap(packet.data, packet.offset, packet.length)) Message(ByteBuffer.wrap(packet.data, packet.offset, packet.length))
} catch (e: IOException) { } catch (e: IOException) { // we cannot parse the message, do not attempt to handle it at all
printLog(e) printLog(e)
return forward(packet) return forward(packet)
} }
if (request.header.opcode != Opcode.QUERY || request.header.rcode != Rcode.NOERROR) return forward(packet) try {
val question = request.question if (forwardOnly || request.header.opcode != Opcode.QUERY) return forward(packet)
if (question?.type != Type.A) return forward(packet) val question = request.question
val host = question.name.toString(true) if (question?.type != Type.A) return forward(packet)
if (remoteDomainMatcher?.containsMatchIn(host) == true) return forward(packet) val host = question.name.toString(true)
val localResults = try { if (remoteDomainMatcher?.containsMatchIn(host) == true) return forward(packet)
io { localResolver(host) } val localResults = try {
} catch (_: TimeoutCancellationException) { io { localResolver(host) }
return forward(packet) } catch (_: TimeoutCancellationException) {
} catch (_: UnknownHostException) { return forward(packet)
} catch (_: UnknownHostException) {
return forward(packet)
}
if (localResults.isEmpty()) return forward(packet)
if (localIpMatcher.isEmpty() || localIpMatcher.any { subnet -> localResults.any(subnet::matches) }) {
val response = Message(request.header.id)
response.header.setFlag(Flags.QR.toInt()) // this is a response
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.header.setFlag(Flags.RA.toInt()) // recursion available
response.addRecord(request.question, Section.QUESTION)
for (address in localResults) response.addRecord(when (address) {
is Inet4Address -> ARecord(request.question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(request.question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
val wire = response.toWire()
return packet.setData(wire, 0, wire.size)
}
return forward(packet) return forward(packet)
} } catch (e: IOException) {
if (localResults.isEmpty()) return forward(packet) printLog(e)
if (localIpMatcher.isEmpty() || localIpMatcher.any { subnet -> localResults.any(subnet::matches) }) {
Log.d("DNS", "$host (local) -> $localResults")
val response = Message(request.header.id) val response = Message(request.header.id)
response.header.setFlag(Flags.QR.toInt()) // this is a response response.header.rcode = Rcode.SERVFAIL
response.header.setFlag(Flags.QR.toInt())
if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt()) if (request.header.getFlag(Flags.RD.toInt())) response.header.setFlag(Flags.RD.toInt())
response.header.setFlag(Flags.RA.toInt()) // recursion available
response.addRecord(request.question, Section.QUESTION) response.addRecord(request.question, Section.QUESTION)
for (address in localResults) response.addRecord(when (address) {
is Inet4Address -> ARecord(request.question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(request.question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
val wire = response.toWire() val wire = response.toWire()
return packet.setData(wire, 0, wire.size) return packet.setData(wire, 0, wire.size)
} }
return forward(packet)
} }
private suspend fun forward(packet: DatagramPacket) = if (tcp) Socket(tcpProxy).useFd { private suspend fun forward(packet: DatagramPacket) = if (tcp) Socket(proxy).useFd {
it.connect(InetSocketAddress(remoteDns, 53), 53) it.connect(InetSocketAddress(remoteDns, 53))
it.getOutputStream().apply { DataOutputStream(it.getOutputStream()).apply {
writeShort(packet.length)
write(packet.data, packet.offset, packet.length) write(packet.data, packet.offset, packet.length)
flush() flush()
} }
val read = it.getInputStream().read(packet.data, 0, UDP_PACKET_SIZE) DataInputStream(it.getInputStream()).apply {
packet.length = if (read < 0) 0 else read packet.length = readUnsignedShort()
} else Socks5DatagramSocket(udpProxy, 0, null).useFd { readFully(packet.data, packet.offset, packet.length)
}
} else Socks5DatagramSocket(proxy).useFd {
val address = packet.address // we are reusing the packet, save it first val address = packet.address // we are reusing the packet, save it first
val port = packet.port
packet.address = remoteDns packet.address = remoteDns
packet.port = 53 packet.port = 53
packet.toString()
Log.d("DNS", "Sending $packet")
it.send(packet) it.send(packet)
Log.d("DNS", "Receiving $packet") packet.length = UDP_PACKET_SIZE
it.receive(packet) it.receive(packet)
Log.d("DNS", "Finished $packet")
packet.address = address packet.address = address
packet.port = port
} }
private suspend fun <T : Closeable> T.useFd(block: (T) -> Unit) { private suspend fun <T : Closeable> T.useFd(block: (T) -> Unit) {
......
/*******************************************************************************
* *
* Copyright (C) 2019 by Max Lv <max.c.lv@gmail.com> *
* Copyright (C) 2019 by Mygod Studio <contact-shadowsocks-android@mygod.be> *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*******************************************************************************/
package com.github.shadowsocks.net
import net.sourceforge.jsocks.Socks5Message
import java.io.ByteArrayInputStream
import java.net.*
import java.nio.ByteBuffer
class Socks5DatagramSocket(proxy: Proxy) : DatagramSocket() {
private val proxy = proxy.address() as InetSocketAddress
override fun send(dp: DatagramPacket) {
val data = ByteBuffer.allocate(6 + dp.address.address.size + dp.length).apply {
// header
put(Socks5Message.SOCKS_VERSION.toByte())
putShort(0)
put(when (dp.address) {
is Inet4Address -> Socks5Message.SOCKS_ATYP_IPV4
is Inet6Address -> Socks5Message.SOCKS_ATYP_IPV6
else -> throw IllegalStateException("Unsupported address type")
}.toByte())
put(dp.address.address)
putShort(dp.port.toShort())
// data
put(dp.data, dp.offset, dp.length)
}.array()
super.send(DatagramPacket(data, data.size, proxy.address, proxy.port))
}
override fun receive(dp: DatagramPacket) {
super.receive(dp)
check(proxy.address == dp.address && proxy.port == dp.port) { "Unexpected packet" }
val stream = ByteArrayInputStream(dp.data, dp.offset, dp.length)
val msg = Socks5Message(stream)
dp.port = msg.port
dp.address = msg.inetAddress
val remaining = stream.available()
dp.setData(dp.data, dp.offset + dp.length - remaining, remaining)
}
}
...@@ -24,7 +24,6 @@ android { ...@@ -24,7 +24,6 @@ android {
targetSdkVersion rootProject.sdkVersion targetSdkVersion rootProject.sdkVersion
versionCode rootProject.versionCode versionCode rootProject.versionCode
versionName rootProject.versionName versionName rootProject.versionName
testApplicationId "com.github.shadowsocks.test"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
resConfigs "fa", "fr", "ja", "ko", "ru", "tr", "zh-rCN", "zh-rTW" resConfigs "fa", "fr", "ja", "ko", "ru", "tr", "zh-rCN", "zh-rTW"
} }
......
...@@ -25,7 +25,6 @@ android { ...@@ -25,7 +25,6 @@ android {
targetSdkVersion rootProject.sdkVersion targetSdkVersion rootProject.sdkVersion
versionCode rootProject.versionCode versionCode rootProject.versionCode
versionName rootProject.versionName versionName rootProject.versionName
testApplicationId "com.github.shadowsocks.tv.test"
resConfigs "fa", "fr", "ja", "ko", "ru", "tr", "zh-rCN", "zh-rTW" resConfigs "fa", "fr", "ja", "ko", "ru", "tr", "zh-rCN", "zh-rTW"
} }
buildTypes { buildTypes {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment