diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java index 2c8b017be..418c515d4 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java @@ -28,9 +28,10 @@ public class FlagdProvider extends EventProvider implements FeatureProvider { private final ReadWriteLock lock = new ReentrantReadWriteLock(); private final Resolver flagResolver; - private ProviderState state = ProviderState.NOT_READY; + private EvaluationContext evaluationContext; + /** * Create a new FlagdProvider instance with default options. */ @@ -60,6 +61,7 @@ public FlagdProvider(final FlagdOptions options) { @Override public void initialize(EvaluationContext evaluationContext) throws Exception { + this.evaluationContext = evaluationContext; this.flagResolver.init(); } @@ -91,27 +93,35 @@ public Metadata getMetadata() { @Override public ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { - return this.flagResolver.booleanEvaluation(key, defaultValue, ctx); + return this.flagResolver.booleanEvaluation(key, defaultValue, mergeContext(ctx)); } @Override public ProviderEvaluation getStringEvaluation(String key, String defaultValue, EvaluationContext ctx) { - return this.flagResolver.stringEvaluation(key, defaultValue, ctx); + return this.flagResolver.stringEvaluation(key, defaultValue, mergeContext(ctx)); } @Override public ProviderEvaluation getDoubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { - return this.flagResolver.doubleEvaluation(key, defaultValue, ctx); + return this.flagResolver.doubleEvaluation(key, defaultValue, mergeContext(ctx)); } @Override public ProviderEvaluation getIntegerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { - return this.flagResolver.integerEvaluation(key, defaultValue, ctx); + return this.flagResolver.integerEvaluation(key, defaultValue, mergeContext(ctx)); } @Override public ProviderEvaluation getObjectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { - return this.flagResolver.objectEvaluation(key, defaultValue, ctx); + return this.flagResolver.objectEvaluation(key, defaultValue, mergeContext(ctx)); + } + + private EvaluationContext mergeContext(final EvaluationContext clientCallCtx) { + if (this.evaluationContext != null) { + return evaluationContext.merge(clientCallCtx); + } + + return clientCallCtx; } private void setState(ProviderState newState) { diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java index 468b0145d..a5e6f3138 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java @@ -1,10 +1,11 @@ package dev.openfeature.contrib.providers.flagd; import com.google.protobuf.Struct; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.CacheType; +import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcConnector; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcResolver; +import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; +import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.CacheType; import dev.openfeature.flagd.grpc.Schema.EventStreamResponse; import dev.openfeature.flagd.grpc.Schema.ResolveBooleanRequest; import dev.openfeature.flagd.grpc.Schema.ResolveBooleanResponse; @@ -16,6 +17,7 @@ import dev.openfeature.flagd.grpc.ServiceGrpc.ServiceBlockingStub; import dev.openfeature.flagd.grpc.ServiceGrpc.ServiceStub; import dev.openfeature.sdk.FlagEvaluationDetails; +import dev.openfeature.sdk.ImmutableContext; import dev.openfeature.sdk.ImmutableMetadata; import dev.openfeature.sdk.MutableContext; import dev.openfeature.sdk.MutableStructure; @@ -48,6 +50,7 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; class FlagdProviderTest { @@ -763,6 +766,39 @@ void disabled_cache() { assertEquals(STATIC_REASON, objectDetails.getReason()); } + @Test + void contextMerging() throws Exception { + // given + final FlagdProvider provider = new FlagdProvider(); + + final Resolver resolverMock = mock(Resolver.class); + + Field flagResolver = FlagdProvider.class.getDeclaredField("flagResolver"); + flagResolver.setAccessible(true); + flagResolver.set(provider, resolverMock); + + final HashMap globalCtxMap = new HashMap<>(); + globalCtxMap.put("id", new Value("GlobalID")); + globalCtxMap.put("env", new Value("A")); + + final HashMap localCtxMap = new HashMap<>(); + localCtxMap.put("id", new Value("localID")); + localCtxMap.put("client", new Value("999")); + + final HashMap expectedCtx = new HashMap<>(); + expectedCtx.put("id", new Value("localID")); + expectedCtx.put("env", new Value("A")); + localCtxMap.put("client", new Value("999")); + + // when + provider.initialize(new ImmutableContext(globalCtxMap)); + provider.getBooleanEvaluation("ket", false, new ImmutableContext(localCtxMap)); + + // then + verify(resolverMock).booleanEvaluation(any(), any(), argThat( + ctx -> ctx.asMap().entrySet().containsAll(expectedCtx.entrySet()))); + } + // test utils // create provider with given grpc connector