Skip to content

Commit

Permalink
Navigate back to AddFirstPaymentMethod immediately after removing las…
Browse files Browse the repository at this point in the history
…t PM (#10107)
  • Loading branch information
amk-stripe authored Feb 7, 2025
1 parent 268f74c commit f6711d1
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ internal class DefaultEmbeddedContentHelper @Inject constructor(
setSelection(null)
},
customerStateHolder = customerStateHolder,
onPaymentMethodRemoved = {
},
prePaymentMethodRemoveActions = {},
postPaymentMethodRemoveActions = {},
onUpdatePaymentMethod = { _, _, _, _ ->
sheetLauncher?.launchManage(
paymentMethodMetadata = paymentMethodMetadata,
Expand All @@ -230,7 +230,6 @@ internal class DefaultEmbeddedContentHelper @Inject constructor(
},
isLinkEnabled = stateFlowOf(paymentMethodMetadata.linkState != null),
isNotPaymentFlow = false,
isEmbedded = true,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ internal class ManageSavedPaymentMethodMutatorFactory @Inject constructor(
selectionHolder.set(null)
},
customerStateHolder = customerStateHolder,
onPaymentMethodRemoved = ::onPaymentMethodRemoved,
prePaymentMethodRemoveActions = {},
postPaymentMethodRemoveActions = ::onPaymentMethodRemoved,
onUpdatePaymentMethod = { displayableSavedPaymentMethod, _, _, _ ->
onUpdatePaymentMethod(displayableSavedPaymentMethod)
},
Expand All @@ -48,7 +49,6 @@ internal class ManageSavedPaymentMethodMutatorFactory @Inject constructor(
},
isLinkEnabled = stateFlowOf(false), // Link is never enabled in the manage screen.
isNotPaymentFlow = false,
isEmbedded = true,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,15 @@ internal class SavedPaymentMethodMutator(
private val selection: StateFlow<PaymentSelection?>,
private val clearSelection: () -> Unit,
private val customerStateHolder: CustomerStateHolder,
private val onPaymentMethodRemoved: () -> Unit,
// Actions that should be taken after removing a payment method has succeeded but before we've fully updated our
// state to reflect that. For example, in our manage payment method screen, we want to navigate back to the
// saved payment methods list before removing the payment method from our state, so that users can see the removed
// payment method get animated out.
private val prePaymentMethodRemoveActions: suspend () -> Unit,
// Actions that should be taken after removing a payment method has succeeded and after our state has been updated.
// For example, closing the embedded manage saved payment methods screen after the final saved payment method is
// removed.
private val postPaymentMethodRemoveActions: () -> Unit,
private val onUpdatePaymentMethod: (
DisplayableSavedPaymentMethod,
canRemove: Boolean,
Expand All @@ -48,7 +56,6 @@ internal class SavedPaymentMethodMutator(
private val navigationPop: () -> Unit,
isLinkEnabled: StateFlow<Boolean?>,
isNotPaymentFlow: Boolean,
private val isEmbedded: Boolean,
) {
val defaultPaymentMethodId: StateFlow<String?> = customerStateHolder.customer.mapAsStateFlow { customerState ->
when (val defaultPaymentMethodState = customerState?.defaultPaymentMethodState) {
Expand Down Expand Up @@ -171,7 +178,7 @@ internal class SavedPaymentMethodMutator(
clearSelection()
}

onPaymentMethodRemoved()
postPaymentMethodRemoveActions()
}

fun updatePaymentMethod(displayableSavedPaymentMethod: DisplayableSavedPaymentMethod) {
Expand All @@ -194,10 +201,7 @@ internal class SavedPaymentMethodMutator(

if (result.isSuccess) {
coroutineScope.launch(workContext) {
if (!isEmbedded || customerStateHolder.paymentMethods.value.size > 1) {
navigationPop()
delay(PaymentMethodRemovalDelayMillis)
}
prePaymentMethodRemoveActions()
removeDeletedPaymentMethodFromState(paymentMethodId = paymentMethodId)
}
}
Expand Down Expand Up @@ -261,19 +265,40 @@ internal class SavedPaymentMethodMutator(
}

companion object {
private fun onPaymentMethodRemoved(viewModel: BaseSheetViewModel) {
val currentScreen = viewModel.navigationHandler.currentScreen.value
val shouldResetToAddPaymentMethodForm =
viewModel.customerStateHolder.paymentMethods.value.isEmpty() &&
currentScreen is PaymentSheetScreen.SelectSavedPaymentMethods

if (shouldResetToAddPaymentMethodForm) {
val interactor = DefaultAddPaymentMethodInteractor.create(
viewModel = viewModel,
paymentMethodMetadata = requireNotNull(viewModel.paymentMethodMetadata.value),
)
val screen = PaymentSheetScreen.AddFirstPaymentMethod(interactor)
viewModel.navigationHandler.resetTo(listOf(screen))
private suspend fun popWithDelay(viewModel: BaseSheetViewModel) {
viewModel.navigationHandler.pop()
delay(PaymentMethodRemovalDelayMillis)
}

private suspend fun navigateBackOnPaymentMethodRemoved(viewModel: BaseSheetViewModel) {
val previousScreen = viewModel.navigationHandler.previousScreen.value

when (previousScreen) {
is PaymentSheetScreen.SelectSavedPaymentMethods -> {
if (viewModel.customerStateHolder.paymentMethods.value.size == 1) {
// If we're removing the last payment method in horizontal mode, we want to transition
// immediately to the AddFirstPaymentMethod screen.
val interactor = DefaultAddPaymentMethodInteractor.create(
viewModel = viewModel,
paymentMethodMetadata = requireNotNull(viewModel.paymentMethodMetadata.value),
)
val screen = PaymentSheetScreen.AddFirstPaymentMethod(interactor)
viewModel.navigationHandler.resetTo(listOf(screen))
} else {
popWithDelay(viewModel)
}
}
is PaymentSheetScreen.ManageSavedPaymentMethods,
is PaymentSheetScreen.VerticalMode -> popWithDelay(viewModel)
is PaymentSheetScreen.AddAnotherPaymentMethod,
is PaymentSheetScreen.AddFirstPaymentMethod,
is PaymentSheetScreen.CvcRecollection,
PaymentSheetScreen.Loading,
is PaymentSheetScreen.UpdatePaymentMethod,
is PaymentSheetScreen.VerticalModeForm,
null -> {
// We don't allow navigating to the payment method remove screen from these screens.
}
}
}

Expand Down Expand Up @@ -327,9 +352,10 @@ internal class SavedPaymentMethodMutator(
selection = viewModel.selection,
customerStateHolder = viewModel.customerStateHolder,
clearSelection = { viewModel.updateSelection(null) },
onPaymentMethodRemoved = {
onPaymentMethodRemoved(viewModel)
prePaymentMethodRemoveActions = {
navigateBackOnPaymentMethodRemoved(viewModel)
},
postPaymentMethodRemoveActions = {},
onUpdatePaymentMethod = { displayableSavedPaymentMethod, canRemove, performRemove, updateExecutor ->
onUpdatePaymentMethod(
viewModel = viewModel,
Expand All @@ -342,7 +368,6 @@ internal class SavedPaymentMethodMutator(
navigationPop = viewModel.navigationHandler::pop,
isLinkEnabled = viewModel.linkHandler.isLinkEnabled,
isNotPaymentFlow = !viewModel.isCompleteFlow,
isEmbedded = false,
).apply {
viewModel.viewModelScope.launch {
viewModel.navigationHandler.currentScreen.collect { currentScreen ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ internal class NavigationHandler<T : Any>(
val currentScreen: StateFlow<T> = backStack
.mapAsStateFlow { it.last() }

val previousScreen: StateFlow<T?> = backStack.mapAsStateFlow {
if (it.isEmpty() || it.size == 1) {
// In these cases, there is no "previous screen".
null
} else {
it[it.size - 2]
}
}

val canGoBack: Boolean
get() = backStack.value.size > 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class SavedPaymentMethodMutatorTest {
assertThat(awaitItem()).isEmpty()
}

assertThat(paymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)

assertThat(calledDetach).isTrue()
}
Expand Down Expand Up @@ -193,7 +193,7 @@ class SavedPaymentMethodMutatorTest {
assertThat(awaitItem()).isFalse()
}

assertThat(paymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
}

@Test
Expand Down Expand Up @@ -251,7 +251,7 @@ class SavedPaymentMethodMutatorTest {
assertThat(awaitItem()).isNull()
}

assertThat(paymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
}

@Test
Expand Down Expand Up @@ -303,8 +303,8 @@ class SavedPaymentMethodMutatorTest {
updatePaymentMethodTurbine.awaitItem().performRemove()

assertThat(calledDetach.awaitItem()).isTrue()
assertThat(navigationPopTurbine.awaitItem()).isNotNull()
assertThat(paymentMethodRemovedTurbine.awaitItem()).isNotNull()
assertThat(prePaymentMethodRemovedTurbine.awaitItem()).isNotNull()
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isNotNull()

assertThat(customerStateHolder.paymentMethods.value).isEmpty()
}
Expand All @@ -313,7 +313,7 @@ class SavedPaymentMethodMutatorTest {
}

@Test
fun `removePaymentMethodInEditScreen calls pop when not in embedded`() {
fun `removePaymentMethodInEditScreen calls prePaymentMethodRemoveActions and postPaymentMethodRemoveActions`() {
val displayableSavedPaymentMethod = PaymentMethodFactory.cards(1).first().toDisplayableSavedPaymentMethod()
val calledDetach = Turbine<Boolean>()
val customerRepository = FakeCustomerRepository(
Expand All @@ -336,80 +336,15 @@ class SavedPaymentMethodMutatorTest {
savedPaymentMethodMutator.removePaymentMethodInEditScreen(displayableSavedPaymentMethod.paymentMethod)

assertThat(calledDetach.awaitItem()).isTrue()
assertThat(navigationPopTurbine.awaitItem()).isNotNull()
assertThat(paymentMethodRemovedTurbine.awaitItem()).isNotNull()
assertThat(prePaymentMethodRemovedTurbine.awaitItem()).isNotNull()
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isNotNull()

assertThat(customerStateHolder.paymentMethods.value).isEmpty()
}

calledDetach.ensureAllEventsConsumed()
}

@Test
fun `removePaymentMethodInEditScreen does not call pop when in embedded`() {
val displayableSavedPaymentMethod = PaymentMethodFactory.cards(1).first().toDisplayableSavedPaymentMethod()
val calledDetach = Turbine<Boolean>()
val customerRepository = FakeCustomerRepository(
onDetachPaymentMethod = { paymentMethodId ->
assertThat(paymentMethodId).isEqualTo(displayableSavedPaymentMethod.paymentMethod.id!!)
calledDetach.add(true)
Result.success(displayableSavedPaymentMethod.paymentMethod)
}
)

runScenario(customerRepository = customerRepository, isEmbedded = true) {
customerStateHolder.setCustomerState(
createCustomerState(
paymentMethods = listOf(displayableSavedPaymentMethod.paymentMethod),
isRemoveEnabled = true,
canRemoveLastPaymentMethod = true,
)
)

savedPaymentMethodMutator.removePaymentMethodInEditScreen(displayableSavedPaymentMethod.paymentMethod)

assertThat(calledDetach.awaitItem()).isTrue()
assertThat(paymentMethodRemovedTurbine.awaitItem()).isNotNull()

assertThat(customerStateHolder.paymentMethods.value).isEmpty()
}

calledDetach.ensureAllEventsConsumed()
}

@Test
fun `removePaymentMethodInEditScreen calls pop when in embedded with multiple cards`() {
val displayableSavedPaymentMethod = PaymentMethodFactory.cards(1).first().toDisplayableSavedPaymentMethod()
val calledDetach = Turbine<Boolean>()
val customerRepository = FakeCustomerRepository(
onDetachPaymentMethod = { paymentMethodId ->
assertThat(paymentMethodId).isEqualTo(displayableSavedPaymentMethod.paymentMethod.id!!)
calledDetach.add(true)
Result.success(displayableSavedPaymentMethod.paymentMethod)
}
)

runScenario(customerRepository = customerRepository, isEmbedded = true) {
customerStateHolder.setCustomerState(
createCustomerState(
paymentMethods = PaymentMethodFactory.cards(2) + displayableSavedPaymentMethod.paymentMethod,
isRemoveEnabled = true,
canRemoveLastPaymentMethod = true,
)
)

savedPaymentMethodMutator.removePaymentMethodInEditScreen(displayableSavedPaymentMethod.paymentMethod)

assertThat(calledDetach.awaitItem()).isTrue()
assertThat(paymentMethodRemovedTurbine.awaitItem()).isNotNull()
assertThat(navigationPopTurbine.awaitItem()).isNotNull()

assertThat(customerStateHolder.paymentMethods.value).hasSize(2)
}

calledDetach.ensureAllEventsConsumed()
}

@Test
fun `updatePaymentMethod performRemove failure callback`() {
val displayableSavedPaymentMethod = PaymentMethodFactory.cards(1).first().toDisplayableSavedPaymentMethod()
Expand Down Expand Up @@ -583,7 +518,7 @@ class SavedPaymentMethodMutatorTest {

savedPaymentMethodMutator.removePaymentMethod(paymentMethod)

assertThat(paymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)
assertThat(postPaymentMethodRemovedTurbine.awaitItem()).isEqualTo(Unit)

assertThat(repository.detachRequests.awaitItem()).isEqualTo(
FakeCustomerRepository.DetachRequest(
Expand All @@ -602,7 +537,6 @@ class SavedPaymentMethodMutatorTest {
private fun runScenario(
customerRepository: CustomerRepository = FakeCustomerRepository(),
isCbcEligible: Boolean = false,
isEmbedded: Boolean = false,
block: suspend Scenario.() -> Unit
) {
runTest {
Expand All @@ -614,7 +548,8 @@ class SavedPaymentMethodMutatorTest {
selection = selection,
)

val paymentMethodRemovedTurbine = Turbine<Unit>()
val postPaymentMethodRemovedTurbine = Turbine<Unit>()
val prePaymentMethodRemovedTurbine = Turbine<Unit>()
val updatePaymentMethodTurbine = Turbine<UpdateCall>()
val navigationPopTurbine = Turbine<Unit>()

Expand All @@ -635,7 +570,8 @@ class SavedPaymentMethodMutatorTest {
selection = selection,
clearSelection = { selection.value = null },
customerStateHolder = customerStateHolder,
onPaymentMethodRemoved = { paymentMethodRemovedTurbine.add(Unit) },
prePaymentMethodRemoveActions = { prePaymentMethodRemovedTurbine.add(Unit) },
postPaymentMethodRemoveActions = { postPaymentMethodRemovedTurbine.add(Unit) },
onUpdatePaymentMethod = { displayableSavedPaymentMethod, canRemove, performRemove, updateExecutor ->
updatePaymentMethodTurbine.add(
UpdateCall(displayableSavedPaymentMethod, canRemove, performRemove, updateExecutor)
Expand All @@ -644,14 +580,14 @@ class SavedPaymentMethodMutatorTest {
navigationPop = { navigationPopTurbine.add(Unit) },
isLinkEnabled = stateFlowOf(false),
isNotPaymentFlow = true,
isEmbedded = isEmbedded,
)
Scenario(
savedPaymentMethodMutator = savedPaymentMethodMutator,
customerStateHolder = customerStateHolder,
selectionSource = selection,
currentScreen = currentScreen,
paymentMethodRemovedTurbine = paymentMethodRemovedTurbine,
prePaymentMethodRemovedTurbine = prePaymentMethodRemovedTurbine,
postPaymentMethodRemovedTurbine = postPaymentMethodRemovedTurbine,
updatePaymentMethodTurbine = updatePaymentMethodTurbine,
navigationPopTurbine = navigationPopTurbine,
testScope = this,
Expand All @@ -661,7 +597,7 @@ class SavedPaymentMethodMutatorTest {

advanceUntilIdle()

paymentMethodRemovedTurbine.ensureAllEventsConsumed()
postPaymentMethodRemovedTurbine.ensureAllEventsConsumed()
updatePaymentMethodTurbine.ensureAllEventsConsumed()
navigationPopTurbine.ensureAllEventsConsumed()
}
Expand All @@ -672,7 +608,8 @@ class SavedPaymentMethodMutatorTest {
val customerStateHolder: CustomerStateHolder,
val selectionSource: MutableStateFlow<PaymentSelection?>,
val currentScreen: MutableStateFlow<PaymentSheetScreen>,
val paymentMethodRemovedTurbine: ReceiveTurbine<Unit>,
val prePaymentMethodRemovedTurbine: ReceiveTurbine<Unit>,
val postPaymentMethodRemovedTurbine: ReceiveTurbine<Unit>,
val updatePaymentMethodTurbine: ReceiveTurbine<UpdateCall>,
val navigationPopTurbine: ReceiveTurbine<Unit>,
val testScope: TestScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,22 @@ internal class NavigationHandlerTest {
verify(screenTwo as Closeable).close()
}
}

@Test
fun `previousScreen value is correct`() = runTest {
val navigationHandler = NavigationHandler<PaymentSheetScreen>(this, PaymentSheetScreen.Loading) {}
navigationHandler.previousScreen.test {
// Initially, there is no previous screen.
assertThat(awaitItem()).isNull()

val screenOne = mock<PaymentSheetScreen>()
navigationHandler.transitionTo(screenOne)
// The previous screen doesn't get updated here -- Loading is removed from the backstack as part of the
// initial loading. The previous screen is still null.

val screenTwo = mock<PaymentSheetScreen>()
navigationHandler.transitionTo(screenTwo)
assertThat(awaitItem()).isEqualTo(screenOne)
}
}
}

0 comments on commit f6711d1

Please sign in to comment.