Skip to content

Commit

Permalink
Abort discovery on bookmark failures and continue on authorization ex…
Browse files Browse the repository at this point in the history
…pired error

This update ensures that discovery gets aborted on `ClientException` with the following codes:
- `Neo.ClientError.Transaction.InvalidBookmark`
- `Neo.ClientError.Transaction.InvalidBookmarkMixture`

In addition, it makes sure that it continues on `AuthorizationExpiredException`.

All security exceptions are mapped to `SecurityException`.
  • Loading branch information
injectives committed Oct 26, 2021
1 parent 4ae0f2c commit bf57016
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.exceptions.AuthorizationExpiredException;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.DiscoveryException;
import org.neo4j.driver.exceptions.FatalDiscoveryException;
import org.neo4j.driver.exceptions.SecurityException;
Expand Down Expand Up @@ -61,6 +63,8 @@ public class RediscoveryImpl implements Rediscovery
private static final String RECOVERABLE_DISCOVERY_ERROR_WITH_SERVER = "Received a recoverable discovery error with server '%s', " +
"will continue discovery with other routing servers if available. " +
"Complete failure is reported separately from this entry.";
private static final String INVALID_BOOKMARK_CODE = "Neo.ClientError.Transaction.InvalidBookmark";
private static final String INVALID_BOOKMARK_MIXTURE_CODE = "Neo.ClientError.Transaction.InvalidBookmarkMixture";

private final BoltServerAddress initialRouter;
private final RoutingSettings settings;
Expand Down Expand Up @@ -278,10 +282,8 @@ private CompletionStage<ClusterComposition> lookupOnRouter( BoltServerAddress ro
private ClusterComposition handleRoutingProcedureError( Throwable error, RoutingTable routingTable,
BoltServerAddress routerAddress, Throwable baseError )
{
if ( error instanceof SecurityException || error instanceof FatalDiscoveryException ||
(error instanceof IllegalStateException && ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE.equals( error.getMessage() )) )
if ( mustAbortDiscovery( error ) )
{
// auth error or routing error happened, terminate the discovery procedure immediately
throw new CompletionException( error );
}

Expand All @@ -295,6 +297,31 @@ private ClusterComposition handleRoutingProcedureError( Throwable error, Routing
return null;
}

private boolean mustAbortDiscovery( Throwable throwable )
{
boolean abort = false;

if ( !(throwable instanceof AuthorizationExpiredException) && throwable instanceof SecurityException )
{
abort = true;
}
else if ( throwable instanceof FatalDiscoveryException )
{
abort = true;
}
else if ( throwable instanceof IllegalStateException && ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE.equals( throwable.getMessage() ) )
{
abort = true;
}
else if ( throwable instanceof ClientException )
{
String code = ((ClientException) throwable).code();
abort = INVALID_BOOKMARK_CODE.equals( code ) || INVALID_BOOKMARK_MIXTURE_CODE.equals( code );
}

return abort;
}

@Override
public List<BoltServerAddress> resolve() throws UnknownHostException
{
Expand Down
56 changes: 38 additions & 18 deletions driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.neo4j.driver.exceptions.FatalDiscoveryException;
import org.neo4j.driver.exceptions.Neo4jException;
import org.neo4j.driver.exceptions.ResultConsumedException;
import org.neo4j.driver.exceptions.SecurityException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.TokenExpiredException;
import org.neo4j.driver.exceptions.TransientException;
Expand Down Expand Up @@ -65,29 +66,38 @@ public static ResultConsumedException newResultConsumedError()

public static Neo4jException newNeo4jError( String code, String message )
{
String classification = extractClassification( code );
switch ( classification )
switch ( extractErrorClass( code ) )
{
case "ClientError":
if ( code.equalsIgnoreCase( "Neo.ClientError.Security.Unauthorized" ) )
if ( "Security".equals( extractErrorSubClass( code ) ) )
{
return new AuthenticationException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Database.DatabaseNotFound" ) )
{
return new FatalDiscoveryException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.AuthorizationExpired" ) )
{
return new AuthorizationExpiredException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.TokenExpired" ) )
{
return new TokenExpiredException( code, message );
if ( code.equalsIgnoreCase( "Neo.ClientError.Security.Unauthorized" ) )
{
return new AuthenticationException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.AuthorizationExpired" ) )
{
return new AuthorizationExpiredException( code, message );
}
else if ( code.equalsIgnoreCase( "Neo.ClientError.Security.TokenExpired" ) )
{
return new TokenExpiredException( code, message );
}
else
{
return new SecurityException( code, message );
}
}
else
{
return new ClientException( code, message );
if ( code.equalsIgnoreCase( "Neo.ClientError.Database.DatabaseNotFound" ) )
{
return new FatalDiscoveryException( code, message );
}
else
{
return new ClientException( code, message );
}
}
case "TransientError":
return new TransientException( code, message );
Expand Down Expand Up @@ -140,7 +150,7 @@ private static boolean isClientOrTransientError( Neo4jException error )
return errorCode != null && (errorCode.contains( "ClientError" ) || errorCode.contains( "TransientError" ));
}

private static String extractClassification( String code )
private static String extractErrorClass( String code )
{
String[] parts = code.split( "\\." );
if ( parts.length < 2 )
Expand All @@ -150,6 +160,16 @@ private static String extractClassification( String code )
return parts[1];
}

private static String extractErrorSubClass( String code )
{
String[] parts = code.split( "\\." );
if ( parts.length < 3 )
{
return "";
}
return parts[2];
}

public static void addSuppressed( Throwable mainError, Throwable error )
{
if ( mainError != error )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import io.netty.util.concurrent.GlobalEventExecutor;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
Expand All @@ -33,6 +35,8 @@
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.exceptions.AuthenticationException;
import org.neo4j.driver.exceptions.AuthorizationExpiredException;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.DiscoveryException;
import org.neo4j.driver.exceptions.ProtocolException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
Expand Down Expand Up @@ -143,6 +147,67 @@ void shouldFailImmediatelyOnAuthError()
verify( table ).forget( A );
}

@Test
void shouldUseAnotherRouterOnAuthorizationExpiredException()
{
ClusterComposition expectedComposition =
new ClusterComposition( 42, asOrderedSet( A, B, C ), asOrderedSet( B, C, D ), asOrderedSet( A, B ), null );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new AuthorizationExpiredException( "Neo.ClientError.Security.AuthorizationExpired", "message" ) );
responsesByAddress.put( B, expectedComposition );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

ClusterComposition actualComposition = await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ).getClusterComposition();

assertEquals( expectedComposition, actualComposition );
verify( table ).forget( A );
verify( table, never() ).forget( B );
verify( table, never() ).forget( C );
}

@ParameterizedTest
@ValueSource( strings = {"Neo.ClientError.Transaction.InvalidBookmark", "Neo.ClientError.Transaction.InvalidBookmarkMixture"} )
void shouldFailImmediatelyOnBookmarkErrors( String code )
{
ClientException error = new ClientException( code, "Invalid" );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new RuntimeException( "Hi!" ) );
responsesByAddress.put( B, error );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

ClientException actualError = assertThrows( ClientException.class,
() -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) );
assertEquals( error, actualError );
verify( table ).forget( A );
}

@Test
void shouldFailImmediatelyOnClosedPoolError()
{
IllegalStateException error = new IllegalStateException( ConnectionPool.CONNECTION_POOL_CLOSED_ERROR_MESSAGE );

Map<BoltServerAddress,Object> responsesByAddress = new HashMap<>();
responsesByAddress.put( A, new RuntimeException( "Hi!" ) );
responsesByAddress.put( B, error );

ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress );
Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( ServerAddressResolver.class ) );
RoutingTable table = routingTableMock( A, B, C );

IllegalStateException actualError = assertThrows( IllegalStateException.class,
() -> await( rediscovery.lookupClusterComposition( table, pool, empty(), null ) ) );
assertEquals( error, actualError );
verify( table ).forget( A );
}

@Test
void shouldFallbackToInitialRouterWhenKnownRoutersFail()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ public class GetFeatures implements TestkitRequest
"Feature:Auth:Kerberos",
"Feature:Auth:Custom",
"Feature:Bolt:4.4",
"Feature:Impersonation"
"Feature:Impersonation",
"Temporary:FastFailingDiscovery"
) );

private static final Set<String> SYNC_FEATURES = new HashSet<>( Arrays.asList(
Expand Down

0 comments on commit bf57016

Please sign in to comment.