Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix leaks on reconnection #123

Merged
merged 1 commit into from
Dec 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package io.rsocket.kotlin.core

import io.ktor.utils.io.core.*
import io.rsocket.kotlin.*
import io.rsocket.kotlin.internal.*
import io.rsocket.kotlin.logging.*
import io.rsocket.kotlin.payload.*
import kotlinx.coroutines.*
Expand Down Expand Up @@ -97,30 +98,33 @@ private class ReconnectableRSocket(
private val state: StateFlow<ReconnectState>,
) : RSocket {

private val reconnectHandler = state.mapNotNull { it.handleState { null } }.take(1)
private val reconnectHandler = state.mapNotNull { it.current() }.take(1)

//null pointer will never happen
private suspend fun currentRSocket(): RSocket = state.value.handleState { reconnectHandler.first() }!!
private suspend fun currentRSocket(closeable: Closeable): RSocket = closeable.closeOnError { currentRSocket() }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense after digging into the code, but the closeOnError confused me for a bit. I think wildcard imports don't help.


private inline fun ReconnectState.handleState(onReconnect: () -> RSocket?): RSocket? = when (this) {
is ReconnectState.Connected -> when {
rSocket.isActive -> rSocket //connection is ready to handle requests
else -> onReconnect() //reconnection
}
private suspend fun currentRSocket(): RSocket = state.value.current() ?: reconnectHandler.first()

private fun ReconnectState.current(): RSocket? = when (this) {
is ReconnectState.Connected -> rSocket.takeIf(RSocket::isActive) //connection is ready to handle requests
is ReconnectState.Failed -> throw error //connection failed - fail requests
ReconnectState.Connecting -> onReconnect() //reconnection
ReconnectState.Connecting -> null //reconnection
}

private suspend inline fun <T : Any> execSuspend(operation: RSocket.() -> T): T =
currentRSocket().operation()
override suspend fun metadataPush(metadata: ByteReadPacket): Unit =
currentRSocket(metadata).metadataPush(metadata)

override suspend fun fireAndForget(payload: Payload): Unit =
currentRSocket(payload).fireAndForget(payload)

private inline fun execFlow(crossinline operation: RSocket.() -> Flow<Payload>): Flow<Payload> =
flow { emitAll(currentRSocket().operation()) }
override suspend fun requestResponse(payload: Payload): Payload =
currentRSocket(payload).requestResponse(payload)

override suspend fun metadataPush(metadata: ByteReadPacket): Unit = execSuspend { metadataPush(metadata) }
override suspend fun fireAndForget(payload: Payload): Unit = execSuspend { fireAndForget(payload) }
override suspend fun requestResponse(payload: Payload): Payload = execSuspend { requestResponse(payload) }
override fun requestStream(payload: Payload): Flow<Payload> = execFlow { requestStream(payload) }
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = execFlow { requestChannel(payloads) }
override fun requestStream(payload: Payload): Flow<Payload> = flow {
emitAll(currentRSocket(payload).requestStream(payload))
}

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = flow {
emitAll(currentRSocket().requestChannel(payloads))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.rsocket.kotlin.core

import app.cash.turbine.*
import io.ktor.utils.io.core.*
import io.rsocket.kotlin.*
import io.rsocket.kotlin.logging.*
import io.rsocket.kotlin.payload.*
Expand Down Expand Up @@ -54,7 +55,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
val connect: suspend () -> RSocket = {
if (first.value) {
first.value = false
rrHandler(firstJob)
handler(firstJob)
} else {
error("Failed to connect")
}
Expand Down Expand Up @@ -89,7 +90,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
first.value = false
error("Failed to connect")
} else {
rrHandler(handlerJob)
handler(handlerJob)
}
}
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
Expand All @@ -114,7 +115,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
error("Failed to connect")
} else {
delay(200) //emulate connection establishment
rrHandler(Job())
handler(Job())
}
}
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
Expand All @@ -137,13 +138,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
when {
first.value -> {
first.value = false
rrHandler(firstJob) //first connection
handler(firstJob) //first connection
}
fails.value < 5 -> {
delay(100)
error("Failed to connect")
}
else -> rrHandler(Job())
else -> handler(Job())
}
}
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
Expand All @@ -170,13 +171,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
when {
first.value -> {
first.value = false
streamHandler(firstJob) //first connection
handler(firstJob) //first connection
}
fails.value < 5 -> {
delay(100)
error("Failed to connect")
}
else -> streamHandler(Job())
else -> handler(Job())
}
}
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
Expand Down Expand Up @@ -206,8 +207,52 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
assertEquals(5, fails.value)
}

private fun rrHandler(job: Job): RSocket = RSocketRequestHandler(job) { requestResponse { it } }
private fun streamHandler(job: Job): RSocket = RSocketRequestHandler(job) {
@Test
fun testNoLeakMetadataPush() = testNoLeaksInteraction { metadataPush(it.data) }

@Test
fun testNoLeakFireAndForget() = testNoLeaksInteraction { fireAndForget(it) }

@Test
fun testNoLeakRequestResponse() = testNoLeaksInteraction { requestResponse(it) }

@Test
fun testNoLeakRequestStream() = testNoLeaksInteraction { requestStream(it).collect() }

private inline fun testNoLeaksInteraction(crossinline interaction: suspend RSocket.(payload: Payload) -> Unit) = test {
val firstJob = Job()
val connect: suspend () -> RSocket = {
if (first.compareAndSet(true, false)) {
handler(firstJob)
} else {
error("Failed to connect")
}
}
val rSocket = ReconnectableRSocket(logger, connect) { _, attempt ->
delay(100)
attempt < 5
}

rSocket.requestResponse(Payload.Empty) //first request to be sure, that connected
firstJob.cancelAndJoin() //cancel

val p = payload("text")
assertFails {
rSocket.interaction(p) //test release on reconnecting
}
assertTrue(p.data.isEmpty)

val p2 = payload("text")
assertFails {
rSocket.interaction(p2) //test release on failed
}
assertTrue(p2.data.isEmpty)
}

private fun handler(job: Job): RSocket = RSocketRequestHandler(job) {
requestResponse { payload ->
payload
}
requestStream {
flow {
repeat(5) {
Expand Down