Commit 490e0432 authored by Mygod's avatar Mygod

Fix ClosedChannelException

parent 00dfb925
...@@ -310,7 +310,7 @@ object BaseService { ...@@ -310,7 +310,7 @@ object BaseService {
data.changeState(CONNECTING) data.changeState(CONNECTING)
data.connectingJob = GlobalScope.launch(Dispatchers.Main) { data.connectingJob = GlobalScope.launch(Dispatchers.Main) {
try { try {
killProcesses() Executable.killAll() // clean up old processes
preInit() preInit()
proxy.init(this@Interface::resolver) proxy.init(this@Interface::resolver)
data.udpFallback?.init(this@Interface::resolver) data.udpFallback?.init(this@Interface::resolver)
......
...@@ -21,22 +21,28 @@ ...@@ -21,22 +21,28 @@
package com.github.shadowsocks.net package com.github.shadowsocks.net
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.io.IOException import java.io.IOException
import java.lang.IllegalStateException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.* import java.nio.channels.*
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.Executors
import kotlin.coroutines.resume
class ChannelMonitor : Thread("ChannelMonitor"), AutoCloseable { class ChannelMonitor {
private data class Registration(val channel: SelectableChannel,
val ops: Int,
val listener: suspend (SelectionKey) -> Unit) {
val result = CompletableDeferred<SelectionKey>()
}
private val job: Job
private val selector = Selector.open() private val selector = Selector.open()
private val registrationPipe = Pipe.open() private val registrationPipe = Pipe.open()
private val pendingRegistrations = ConcurrentLinkedQueue<Triple<SelectableChannel, Int, (SelectionKey) -> Unit>>() private val pendingRegistrations = Channel<Registration>()
@Volatile @Volatile
private var running = true private var running = true
private fun registerInternal(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit) = private fun registerInternal(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit) =
channel.register(selector, ops, block) channel.register(selector, ops, block)
init { init {
...@@ -45,52 +51,63 @@ class ChannelMonitor : Thread("ChannelMonitor"), AutoCloseable { ...@@ -45,52 +51,63 @@ class ChannelMonitor : Thread("ChannelMonitor"), AutoCloseable {
registerInternal(this, SelectionKey.OP_READ) { registerInternal(this, SelectionKey.OP_READ) {
val junk = ByteBuffer.allocateDirect(1) val junk = ByteBuffer.allocateDirect(1)
while (read(junk) > 0) { while (read(junk) > 0) {
val (channel, ops, block) = pendingRegistrations.remove() pendingRegistrations.receive().apply {
registerInternal(channel, ops, block) try {
result.complete(registerInternal(channel, ops, listener))
} catch (e: ClosedChannelException) {
result.completeExceptionally(e)
}
}
junk.clear() junk.clear()
} }
} }
} }
start() job = GlobalScope.launch(Executors.newSingleThreadExecutor().asCoroutineDispatcher()) {
while (running) {
val num = try {
selector.select()
} catch (e: IOException) {
printLog(e)
continue
}
if (num <= 0) continue
val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
iterator.remove()
(key.attachment() as suspend (SelectionKey) -> Unit)(key)
}
}
}
} }
fun register(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit) { suspend fun register(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit): SelectionKey {
pendingRegistrations.add(Triple(channel, ops, block)) ByteBuffer.allocateDirect(1).also { junk ->
val junk = ByteBuffer.allocateDirect(1) loop@ while (running) when (registrationPipe.sink().write(junk)) {
while (running && registrationPipe.sink().write(junk) == 0); 0 -> yield()
1 -> break@loop
else -> throw IOException("Failed to register in the channel")
}
}
if (!running) throw ClosedChannelException()
return Registration(channel, ops, block).run {
pendingRegistrations.send(this)
result.await()
}
} }
suspend fun wait(channel: SelectableChannel, ops: Int) = suspendCancellableCoroutine<Unit> { cont -> suspend fun wait(channel: SelectableChannel, ops: Int) = CompletableDeferred<SelectionKey>().run {
register(channel, ops) { register(channel, ops) {
if (it.isValid) it.interestOps(0) // stop listening if (it.isValid) it.interestOps(0) // stop listening
try { complete(it)
cont.resume(Unit)
} catch (_: IllegalStateException) { } // already resumed by a timeout, maybe should use tryResume?
}
}
override fun run() {
while (running) {
val num = try {
selector.select()
} catch (e: IOException) {
printLog(e)
continue
}
if (num <= 0) continue
val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
iterator.remove()
(key.attachment() as (SelectionKey) -> Unit)(key)
}
} }
await()
} }
override fun close() { suspend fun close() {
running = false running = false
selector.wakeup() selector.wakeup()
join() job.join()
selector.keys().forEach { it.channel().close() } selector.keys().forEach { it.channel().close() }
selector.close() selector.close()
} }
......
...@@ -67,14 +67,14 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -67,14 +67,14 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private val job = SupervisorJob() private val job = SupervisorJob()
override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) } override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
fun start(listen: SocketAddress) = DatagramChannel.open().apply { suspend fun start(listen: SocketAddress) = DatagramChannel.open().apply {
configureBlocking(false) configureBlocking(false)
socket().bind(listen) socket().bind(listen)
monitor.register(this, SelectionKey.OP_READ) { monitor.register(this, SelectionKey.OP_READ) {
val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE) val buffer = ByteBuffer.allocateDirect(UDP_PACKET_SIZE)
val source = receive(buffer)!! val source = receive(buffer)!!
buffer.flip() buffer.flip()
launch { this@LocalDnsServer.launch {
val reply = resolve(buffer) val reply = resolve(buffer)
while (send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE) while (send(reply, source) <= 0) monitor.wait(this@apply, SelectionKey.OP_WRITE)
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment