Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abort discovery on bookmark failures and continue on authorization expired error #1043

Merged
merged 1 commit into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.exceptions.AuthorizationExpiredException;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.DiscoveryException;
import org.neo4j.driver.exceptions.FatalDiscoveryException;
import org.neo4j.driver.exceptions.SecurityException;
Expand Down Expand Up @@ -61,6 +63,8 @@ public class RediscoveryImpl implements Rediscovery
private static final String RECOVERABLE_DISCOVERY_ERROR_WITH_SERVER = "Received a recoverable discovery error with server '%s', " +
"will continue discovery with other routing servers if available. " +
"Complete failure is reported separately from this entry.";
private static final String INVALID_BOOKMARK_CODE = "Neo.ClientError.Transaction.InvalidBookmark";
private static final String INVALID_BOOKMARK_MIXTURE_CODE = "Neo.ClientError.Transaction.InvalidBookmarkMixture";

private final BoltServerAddress initialRouter;
private final RoutingSettings settings;
Expand Down Expand Up @@ -278,10 +282,8 @@ private CompletionStage<ClusterComposition> lookupOnRouter( BoltServerAddress ro
private ClusterComposition handleRoutingProcedureError( Throwable error, RoutingTable routingTable,
BoltServerAddress routerAddress, Throwable baseError )
{
if ( error instanceof SecurityException || error instanceof FatalDiscoveryException ||
(error instanceof IllegalStateException && ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE.equals( error.getMessage() )) )
if ( mustAbortDiscovery( error ) )
{
// auth error or routing error happened, terminate the discovery procedure immediately
throw new CompletionException( error );
}

Expand All @@ -295,6 +297,31 @@ private ClusterComposition handleRoutingProcedureError( Throwable error, Routing
return null;
}

private boolean mustAbortDiscovery( Throwable throwable )
{
boolean abort = false;

if ( !(throwable instanceof AuthorizationExpiredException) && throwable instanceof SecurityException )
{
abort = true;
}
else if ( throwable instanceof FatalDiscoveryException )
{
abort = true;
}
else if ( throwable instanceof IllegalStateException && ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE.equals( throwable.getMessage() ) )
{
abort = true;
}
else if ( throwable instanceof ClientException )
{
String code = ((ClientException) throwable).code();
abort = INVALID_BOOKMARK_CODE.equals( code ) || INVALID_BOOKMARK_MIXTURE_CODE.equals( code );
}

return abort;
}

@Override
public List<BoltServerAddress> resolve() throws UnknownHostException
{
Expand Down
56 changes: 38 additions & 18 deletions driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.neo4j.driver.exceptions.FatalDiscoveryException;
import org.neo4j.driver.exceptions.Neo4jException;
import org.neo4j.driver.exceptions.ResultConsumedException;
import org.neo4j.driver.exceptions.SecurityException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.TokenExpiredException;
import org.neo4j.driver.exceptions.TransientException;
Expand Down Expand Up @@ -65,29 +66,38 @@ public static ResultConsumedException newResultConsumedError()

public static Neo4jException newNeo4jError( String code, String message )
{
String classification = extractClassification( code );
switch ( classification )
switch ( extractErrorClass( code ) )
{
case "ClientError":
if ( code.equalsIgnoreCase( "Neo.ClientError.Security.Unauthorized" ) )
if ( "Security".equals( extractErrorSubClass( code ) ) )
{
return new AuthenticationException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Database.DatabaseNotFound" ) )
{
return new FatalDiscoveryException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.AuthorizationExpired" ) )
{
return new AuthorizationExpiredException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.TokenExpired" ) )
{
return new TokenExpiredException( code, message );
if ( code.equalsIgnoreCase( "Neo.ClientError.Security.Unauthorized" ) )
{
return new AuthenticationException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.AuthorizationExpired" ) )
{
return new AuthorizationExpiredException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.TokenExpired" ) )
{
return new TokenExpiredException( code, message );
}
else
{
return new SecurityException( code, message );
}
}
else
{
return new ClientException( code, message );
if ( code.equalsIgnoreCase( "Neo.ClientError.Database.DatabaseNotFound" ) )
{
return new FatalDiscoveryException( code, message );
}
else
{
return new ClientException( code, message );
}
}
case "TransientError":
return new TransientException( code, message );
Expand Down Expand Up @@ -140,7 +150,7 @@ private static boolean isClientOrTransientError( Neo4jException error )
return errorCode != null && (errorCode.contains( "ClientError" ) || errorCode.contains( "TransientError" ));
}

private static String extractClassification( String code )
private static String extractErrorClass( String code )
{
String[] parts = code.split( "\\." );
if ( parts.length < 2 )
Expand All @@ -150,6 +160,16 @@ private static String extractClassification( String code )
return parts[1];
}

private static String extractErrorSubClass( String code )
{
String[] parts = code.split( "\\." );
if ( parts.length < 3 )
{
return "";
}
return parts[2];
}

public static void addSuppressed( Throwable mainError, Throwable error )
{
if ( mainError != error )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import io.netty.util.concurrent.GlobalEventExecutor;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
Expand All @@ -33,6 +35,8 @@
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.exceptions.AuthenticationException;
import org.neo4j.driver.exceptions.AuthorizationExpiredException;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.DiscoveryException;
import org.neo4j.driver.exceptions.ProtocolException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
Expand Down Expand Up @@ -143,6 +147,67 @@ void shouldFailImmediatelyOnAuthError()
verify( table ).forget( A );
}

@Test
void shouldUseAnotherRouterOnAuthorizationExpiredException()
{
ClusterComposition expectedComposition =
new ClusterComposition( 42, asOrderedSet( A, B, C ), asOrderedSet( B, C, D ), asOrderedSet( A, B ), null );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new AuthorizationExpiredException( "Neo.ClientError.Security.AuthorizationExpired", "message" ) );
responsesByAddress.put( B, expectedComposition );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition();

assertEquals( expectedComposition, actualComposition );
verify( table ).forget( A );
verify( table, never() ).forget( B );
verify( table, never() ).forget( C );
}

@ParameterizedTest
@ValueSource( strings = {"Neo.ClientError.Transaction.InvalidBookmark", "Neo.ClientError.Transaction.InvalidBookmarkMixture"} )
void shouldFailImmediatelyOnBookmarkErrors( String code )
{
ClientException error = new ClientException( code, "Invalid" );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new RuntimeException( "Hi!" ) );
responsesByAddress.put( B, error );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

ClientException actualError = assertThrows( ClientException.class,
() -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) );
assertEquals( error, actualError );
verify( table ).forget( A );
}

@Test
void shouldFailImmediatelyOnClosedPoolError()
{
IllegalStateException error = new IllegalStateException( ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new RuntimeException( "Hi!" ) );
responsesByAddress.put( B, error );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

IllegalStateException actualError = assertThrows( IllegalStateException.class,
() -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) );
assertEquals( error, actualError );
verify( table ).forget( A );
}

@Test
void shouldFallbackToInitialRouterWhenKnownRoutersFail()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ public class GetFeatures implements TestkitRequest
"Feature:Auth:Kerberos",
"Feature:Auth:Custom",
"Feature:Bolt:4.4",
"Feature:Impersonation"
"Feature:Impersonation",
"Temporary:FastFailingDiscovery"
) );

private static final Set<String> SYNC_FEATURES = new HashSet<>( Arrays.asList(
Expand Down