diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java index 6513dbc6ab6a..1ecef7f73332 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java @@ -352,9 +352,9 @@ protected Object invokeWithinTransaction(Method method, @Nullable Class targe } return new ReactiveTransactionSupport(adapter); }); - Publisher publisher = (Publisher) txSupport.invokeWithinTransaction(method, targetClass, invocation, txAttr, (ReactiveTransactionManager) tm); - return (isSuspendingFunction ? (hasSuspendingFlowReturnType ? KotlinDelegate.asFlow(publisher) : - KotlinDelegate.awaitSingleOrNull(publisher, ((CoroutinesInvocationCallback) invocation).getContinuation())) : publisher); + Object result = txSupport.invokeWithinTransaction(method, targetClass, invocation, txAttr, (ReactiveTransactionManager) tm); + return (isSuspendingFunction ? (hasSuspendingFlowReturnType ? KotlinDelegate.asFlow((Publisher) result) : + KotlinDelegate.awaitSingleOrNull((Publisher) result, ((CoroutinesInvocationCallback) invocation).getContinuation())) : result); } PlatformTransactionManager ptm = asPlatformTransactionManager(tm); diff --git a/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt b/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt index 76dde1c9fd3a..fa823eb79f0a 100644 --- a/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt +++ b/spring-tx/src/test/kotlin/org/springframework/transaction/annotation/CoroutinesAnnotationTransactionInterceptorTests.kt @@ -17,6 +17,9 @@ package org.springframework.transaction.annotation import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions import org.junit.jupiter.api.Test @@ -83,14 +86,38 @@ class CoroutinesAnnotationTransactionInterceptorTests { runBlocking { try { proxy.suspendingValueFailure() + Assertions.fail("No exception thrown as expected") } catch (ex: IllegalStateException) { } - } assertReactiveGetTransactionAndRollbackCount(1) } + @Test + fun suspendingFlowSuccess() { + val proxyFactory = ProxyFactory() + proxyFactory.setTarget(TestWithCoroutines()) + proxyFactory.addAdvice(TransactionInterceptor(rtm, source)) + val proxy = proxyFactory.proxy as TestWithCoroutines + runBlocking { + Assertions.assertThat(proxy.suspendingFlowSuccess().toList()).containsExactly("foo", "foo") + } + assertReactiveGetTransactionAndCommitCount(1) + } + + @Test + fun flowSuccess() { + val proxyFactory = ProxyFactory() + proxyFactory.setTarget(TestWithCoroutines()) + proxyFactory.addAdvice(TransactionInterceptor(rtm, source)) + val proxy = proxyFactory.proxy as TestWithCoroutines + runBlocking { + Assertions.assertThat(proxy.flowSuccess().toList()).containsExactly("foo", "foo") + } + assertReactiveGetTransactionAndCommitCount(1) + } + private fun assertReactiveGetTransactionAndCommitCount(expectedCount: Int) { Assertions.assertThat(rtm.begun).isEqualTo(expectedCount) Assertions.assertThat(rtm.commits).isEqualTo(expectedCount) @@ -122,5 +149,22 @@ class CoroutinesAnnotationTransactionInterceptorTests { delay(10) throw IllegalStateException() } + + open fun flowSuccess(): Flow { + return flow { + emit("foo") + delay(10) + emit("foo") + } + } + + open suspend fun suspendingFlowSuccess(): Flow { + delay(10) + return flow { + emit("foo") + delay(10) + emit("foo") + } + } } }