Commit b6880efd authored by Mygod's avatar Mygod

Protect sockets concurrently

parent 075225db
......@@ -45,7 +45,7 @@ abstract class LocalSocketListener(name: String, socketFile: File) : Thread(name
final override fun run() = localSocket.use {
while (running) {
try {
serverSocket.accept().use { accept(it) }
accept(serverSocket.accept())
} catch (e: IOException) {
if (running) printLog(e)
continue
......
......@@ -30,7 +30,7 @@ import java.nio.ByteOrder
class TrafficMonitor(statFile: File) : AutoCloseable {
private val thread = object : LocalSocketListener("TrafficMonitor", statFile) {
override fun accept(socket: LocalSocket) {
override fun accept(socket: LocalSocket) = socket.use {
val buffer = ByteArray(16)
if (socket.inputStream.read(buffer) != 16) throw IOException("Unexpected traffic stat length")
val stat = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN)
......
......@@ -39,6 +39,7 @@ import com.github.shadowsocks.utils.Key
import com.github.shadowsocks.utils.Subnet
import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.*
import java.io.Closeable
import java.io.File
import java.io.FileDescriptor
......@@ -79,23 +80,37 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
}
private inner class ProtectWorker :
LocalSocketListener("ShadowsocksVpnThread", File(Core.deviceStorage.noBackupFilesDir, "protect_path")) {
LocalSocketListener("ShadowsocksVpnThread", File(Core.deviceStorage.noBackupFilesDir, "protect_path")),
CoroutineScope {
private val job = SupervisorJob()
override val coroutineContext get() = Dispatchers.IO + job + CoroutineExceptionHandler { _, t -> printLog(t) }
override fun accept(socket: LocalSocket) {
socket.inputStream.read()
val fd = socket.ancillaryFileDescriptors!!.single()!!
CloseableFd(fd).use {
socket.outputStream.write(if (underlyingNetwork.let { network ->
if (network != null && Build.VERSION.SDK_INT >= 23) try {
network.bindSocket(fd)
true
} catch (e: IOException) {
// suppress ENONET (Machine is not on the network)
if ((e.cause as? ErrnoException)?.errno != 64) printLog(e)
false
} else protect(getInt.invoke(fd) as Int)
}) 0 else 1)
launch {
socket.use {
socket.inputStream.read()
val fd = socket.ancillaryFileDescriptors!!.single()!!
CloseableFd(fd).use {
socket.outputStream.write(if (underlyingNetwork.let { network ->
if (network != null && Build.VERSION.SDK_INT >= 23) try {
network.bindSocket(fd)
true
} catch (e: IOException) {
// suppress ENONET (Machine is not on the network)
if ((e.cause as? ErrnoException)?.errno != 64) printLog(e)
false
} else protect(getInt.invoke(fd) as Int)
}) 0 else 1)
}
}
}
}
suspend fun shutdown() {
job.cancel()
close()
job.join()
}
}
inner class NullConnectionException : NullPointerException() {
override fun getLocalizedMessage() = getString(R.string.reboot_required)
......@@ -143,7 +158,7 @@ class VpnService : BaseVpnService(), LocalDnsService.Interface {
connectivity.unregisterNetworkCallback(defaultNetworkCallback)
listeningForDefaultNetwork = false
}
worker?.close()
worker?.shutdown()
worker = null
super.killProcesses()
conn?.close()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment