Commit 0efe2e8a authored by Mygod's avatar Mygod

Revert using a thread pool

parent 0d5834eb
...@@ -21,21 +21,22 @@ ...@@ -21,21 +21,22 @@
package com.github.shadowsocks.net package com.github.shadowsocks.net
import com.github.shadowsocks.utils.printLog import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.* import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import java.io.IOException import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.* import java.nio.channels.*
import java.util.concurrent.Executors
class ChannelMonitor { class ChannelMonitor(private val scope: CoroutineScope) : Thread("ChannelMonitor") {
private data class Registration(val channel: SelectableChannel, private data class Registration(val channel: SelectableChannel,
val ops: Int, val ops: Int,
val listener: suspend (SelectionKey) -> Unit) { val listener: suspend (SelectionKey) -> Unit) {
val result = CompletableDeferred<SelectionKey>() val result = CompletableDeferred<SelectionKey>()
} }
private val job: Job
private val selector = Selector.open() private val selector = Selector.open()
private val registrationPipe = Pipe.open() private val registrationPipe = Pipe.open()
private val pendingRegistrations = Channel<Registration>() private val pendingRegistrations = Channel<Registration>()
...@@ -62,29 +63,13 @@ class ChannelMonitor { ...@@ -62,29 +63,13 @@ class ChannelMonitor {
} }
} }
} }
job = GlobalScope.launch(Executors.newSingleThreadExecutor().asCoroutineDispatcher()) { start()
while (running) {
val num = try {
selector.select()
} catch (e: IOException) {
printLog(e)
continue
}
if (num <= 0) continue
val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) {
val key = iterator.next()
iterator.remove()
(key.attachment() as suspend (SelectionKey) -> Unit)(key)
}
}
}
} }
suspend fun register(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit): SelectionKey { suspend fun register(channel: SelectableChannel, ops: Int, block: suspend (SelectionKey) -> Unit): SelectionKey {
ByteBuffer.allocateDirect(1).also { junk -> ByteBuffer.allocateDirect(1).also { junk ->
loop@ while (running) when (registrationPipe.sink().write(junk)) { loop@ while (running) when (registrationPipe.sink().write(junk)) {
0 -> yield() 0 -> kotlinx.coroutines.yield()
1 -> break@loop 1 -> break@loop
else -> throw IOException("Failed to register in the channel") else -> throw IOException("Failed to register in the channel")
} }
...@@ -104,11 +89,28 @@ class ChannelMonitor { ...@@ -104,11 +89,28 @@ class ChannelMonitor {
await() await()
} }
override fun run() {
while (running) {
val num = try {
selector.select()
} catch (e: IOException) {
printLog(e)
continue
}
if (num <= 0) continue
val iterator = selector.selectedKeys().iterator()
while (iterator.hasNext()) {
iterator.next().let { scope.launch { (it.attachment() as suspend (SelectionKey) -> Unit)(it) } }
iterator.remove()
}
}
}
fun close(scope: CoroutineScope) { fun close(scope: CoroutineScope) {
running = false running = false
selector.wakeup() selector.wakeup()
scope.launch { scope.launch(Dispatchers.IO) { // thread joining is a blocking operation
job.join() join()
selector.keys().forEach { it.channel().close() } selector.keys().forEach { it.channel().close() }
selector.close() selector.close()
} }
......
...@@ -62,7 +62,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd ...@@ -62,7 +62,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512 private const val UDP_PACKET_SIZE = 512
} }
private val monitor = ChannelMonitor() private val monitor = ChannelMonitor(this)
private val job = SupervisorJob() private val job = SupervisorJob()
override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) } override val coroutineContext = Dispatchers.Default + job + CoroutineExceptionHandler { _, t -> printLog(t) }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment