Commit 493b6e85 authored by Mygod's avatar Mygod

Fix register called in wrong thread

parent 0efe2e8a
...@@ -21,29 +21,26 @@ ...@@ -21,29 +21,26 @@
package com.github.shadowsocks.net package com.github.shadowsocks.net
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import java.io.IOException import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.* import java.nio.channels.*
class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor") { class ChannelMonitor : Thread("ChannelMonitor") {
private data class Registration(val channel: SelectableChannel, private data class Registration(val channel: SelectableChannel,
val ops: Int, val ops: Int,
val listener: suspend (SelectionKey) -> Unit) { val listener: (SelectionKey) -> Unit) {
val result = CompletableDeferred<SelectionKey>() val result = CompletableDeferred<SelectionKey>()
} }
private val selector = Selector.open() private val selector = Selector.open()
private val registrationPipe = Pipe.open() private val registrationPipe = Pipe.open()
private val pendingRegistrations = Channel<Registration>() private val pendingRegistrations = Channel<Registration>(Channel.UNLIMITED)
@Volatile @Volatile
private var running = true private var running = true
private fun registerInternal(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit) = private fun registerInternal(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit) =
channel.register(selector, ops, block) channel.register(selector, ops, block)
init { init {
...@@ -52,7 +49,7 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor ...@@ -52,7 +49,7 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor
registerInternal(this, SelectionKey.OP_READ) { registerInternal(this, SelectionKey.OP_READ) {
val junk = ByteBuffer.allocateDirect(1) val junk = ByteBuffer.allocateDirect(1)
while (read(junk) > 0) { while (read(junk) > 0) {
pendingRegistrations.receive().apply { pendingRegistrations.poll()!!.apply {
try { try {
result.complete(registerInternal(channel, ops, listener)) result.complete(registerInternal(channel, ops, listener))
} catch (e: ClosedChannelException) { } catch (e: ClosedChannelException) {
...@@ -66,7 +63,9 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor ...@@ -66,7 +63,9 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor
start() start()
} }
suspend fun register(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit): SelectionKey { suspend fun register(channel: SelectableChannel, ops: Int, block: (SelectionKey) -> Unit): SelectionKey {
val registration = Registration(channel, ops, block)
pendingRegistrations.send(registration)
ByteBuffer.allocateDirect(1).also { junk -> ByteBuffer.allocateDirect(1).also { junk ->
loop@ while (running) when (registrationPipe.sink().write(junk)) { loop@ while (running) when (registrationPipe.sink().write(junk)) {
0 -> kotlinx.coroutines.yield() 0 -> kotlinx.coroutines.yield()
...@@ -75,10 +74,7 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor ...@@ -75,10 +74,7 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor
} }
} }
if (!running) throw ClosedChannelException() if (!running) throw ClosedChannelException()
return Registration(channel, ops, block).run { return registration.result.await()
pendingRegistrations.send(this)
result.await()
}
} }
suspend fun wait(channel: SelectableChannel, ops: Int) = CompletableDeferred<SelectionKey>().run { suspend fun wait(channel: SelectableChannel, ops: Int) = CompletableDeferred<SelectionKey>().run {
...@@ -100,8 +96,9 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor ...@@ -100,8 +96,9 @@ class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor
if (num <= 0) continue if (num <= 0) continue
val iterator = selector.selectedKeys().iterator() val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) { while (iterator.hasNext()) {
iterator.next().let { scope.launch { (it.attachment() as suspend (SelectionKey) -> Unit)(it) } } val key = iterator.next()
iterator.remove() iterator.remove()
(key.attachment() as (SelectionKey) -> Unit)(key)
} }
} }
} }
......
...@@ -62,7 +62,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -62,7 +62,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512 private const val UDP_PACKET_SIZE = 512
} }
private val monitor = ChannelMonitor(this) private val monitor = ChannelMonitor()
private val job = SupervisorJob() private val job = SupervisorJob()
override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) } override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment