Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split RestateContext interface in KeyedContext/UnkeyedContext #213

Merged
merged 4 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions examples/src/main/java/dev/restate/sdk/examples/Counter.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.examples;

import dev.restate.sdk.RestateContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
import dev.restate.sdk.examples.generated.*;
Expand All @@ -23,30 +23,28 @@ public class Counter extends CounterRestate.CounterRestateImplBase {
private static final StateKey<Long> TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG);

@Override
public void reset(RestateContext ctx, CounterRequest request) {
restateContext().clear(TOTAL);
public void reset(KeyedContext ctx, CounterRequest request) {
ctx.clear(TOTAL);
}

@Override
public void add(RestateContext ctx, CounterAddRequest request) {
public void add(KeyedContext ctx, CounterAddRequest request) {
long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
ctx.set(TOTAL, newValue);
}

@Override
public GetResponse get(RestateContext context, CounterRequest request) {
long currentValue = restateContext().get(TOTAL).orElse(0L);
public GetResponse get(KeyedContext ctx, CounterRequest request) {
long currentValue = ctx.get(TOTAL).orElse(0L);

return GetResponse.newBuilder().setValue(currentValue).build();
}

@Override
public CounterUpdateResult getAndAdd(RestateContext context, CounterAddRequest request) {
public CounterUpdateResult getAndAdd(KeyedContext ctx, CounterAddRequest request) {
LOG.info("Invoked get and add with " + request.getValue());

RestateContext ctx = restateContext();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
ctx.set(TOTAL, newValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
package dev.restate.sdk.examples;

import com.google.protobuf.Empty;
import dev.restate.sdk.RestateContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.RestateService;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
Expand All @@ -27,15 +27,15 @@ public class VanillaGrpcCounter extends CounterGrpc.CounterImplBase implements R

@Override
public void reset(CounterRequest request, StreamObserver<Empty> responseObserver) {
restateContext().clear(TOTAL);
KeyedContext.current().clear(TOTAL);

responseObserver.onNext(Empty.getDefaultInstance());
responseObserver.onCompleted();
}

@Override
public void add(CounterAddRequest request, StreamObserver<Empty> responseObserver) {
RestateContext ctx = restateContext();
KeyedContext ctx = KeyedContext.current();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
Expand All @@ -47,7 +47,7 @@ public void add(CounterAddRequest request, StreamObserver<Empty> responseObserve

@Override
public void get(CounterRequest request, StreamObserver<GetResponse> responseObserver) {
long currentValue = restateContext().get(TOTAL).orElse(0L);
long currentValue = KeyedContext.current().get(TOTAL).orElse(0L);

responseObserver.onNext(GetResponse.newBuilder().setValue(currentValue).build());
responseObserver.onCompleted();
Expand All @@ -58,7 +58,7 @@ public void getAndAdd(
CounterAddRequest request, StreamObserver<CounterUpdateResult> responseObserver) {
LOG.info("Invoked get and add with " + request.getValue());

RestateContext ctx = restateContext();
KeyedContext ctx = KeyedContext.current();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
Expand Down
12 changes: 6 additions & 6 deletions examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import dev.restate.sdk.common.CoreSerdes
import dev.restate.sdk.common.StateKey
import dev.restate.sdk.examples.generated.*
import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder
import dev.restate.sdk.kotlin.RestateContext
import dev.restate.sdk.kotlin.KeyedContext
import org.apache.logging.log4j.LogManager

class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {
Expand All @@ -21,20 +21,20 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {

private val TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG)

override suspend fun reset(context: RestateContext, request: CounterRequest) {
override suspend fun reset(context: KeyedContext, request: CounterRequest) {
context.clear(TOTAL)
}

override suspend fun add(context: RestateContext, request: CounterAddRequest) {
override suspend fun add(context: KeyedContext, request: CounterAddRequest) {
updateCounter(context, request.value)
}

override suspend fun get(context: RestateContext, request: CounterRequest): GetResponse {
override suspend fun get(context: KeyedContext, request: CounterRequest): GetResponse {
return getResponse { value = context.get(TOTAL) ?: 0L }
}

override suspend fun getAndAdd(
context: RestateContext,
context: KeyedContext,
request: CounterAddRequest
): CounterUpdateResult {
LOG.info("Invoked get and add with " + request.value)
Expand All @@ -45,7 +45,7 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {
}
}

private suspend fun updateCounter(context: RestateContext, add: Long): Pair<Long, Long> {
private suspend fun updateCounter(context: KeyedContext, add: Long): Pair<Long, Long> {
val currentValue = context.get(TOTAL) ?: 0L
val newValue = currentValue + add

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.salesforce.jprotoc.ProtoTypeMap;
import com.salesforce.jprotoc.ProtocPlugin;
import dev.restate.generated.ext.Ext;
import dev.restate.generated.ext.ServiceType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -120,6 +121,12 @@ private ServiceContext buildServiceContext(
serviceContext.serviceName = serviceProto.getName();
serviceContext.deprecated = serviceProto.getOptions().getDeprecated();

// Resolve context type
serviceContext.contextType =
serviceProto.getOptions().getExtension(Ext.serviceType) == ServiceType.UNKEYED
? "UnkeyedContext"
: "KeyedContext";

// Resolve javadoc
DescriptorProtos.SourceCodeInfo.Location serviceLocation =
locations.stream()
Expand Down Expand Up @@ -215,6 +222,7 @@ private static class ServiceContext {
public String packageName;
public String className;
public String serviceName;
public String contextType;
public String apidoc;
public boolean deprecated;
public final List<MethodContext> methods = new ArrayList<>();
Expand Down
42 changes: 18 additions & 24 deletions protoc-gen-restate/src/main/resources/javaStub.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
package {{packageName}};
{{/packageName}}

import dev.restate.sdk.RestateContext;
import dev.restate.sdk.UnkeyedContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.Awaitable;
import dev.restate.sdk.common.syscalls.Syscalls;
import java.time.Duration;
Expand All @@ -15,24 +16,17 @@ public class {{className}} {
private {{className}}() {}

/**
* Create a new client.
*/
public static {{serviceName}}RestateClient newClient() {
return newClient(RestateContext.fromSyscalls(Syscalls.current()));
}

/**
* Create a new client from the given {@link RestateContext}.
* Create a new client from the given {@link KeyedContext}.
*/
public static {{serviceName}}RestateClient newClient(RestateContext ctx) {
public static {{serviceName}}RestateClient newClient(UnkeyedContext ctx) {
return new {{serviceName}}RestateClient(ctx);
}

{{{apidoc}}}
public static final class {{serviceName}}RestateClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;

{{serviceName}}RestateClient(RestateContext ctx) {
{{serviceName}}RestateClient(UnkeyedContext ctx) {
this.ctx = ctx;
}

Expand All @@ -57,9 +51,9 @@ public class {{className}} {
}

public static final class {{serviceName}}RestateOneWayClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;

{{serviceName}}RestateOneWayClient(RestateContext ctx) {
{{serviceName}}RestateOneWayClient(UnkeyedContext ctx) {
this.ctx = ctx;
}

Expand All @@ -74,10 +68,10 @@ public class {{className}} {
}

public static final class {{serviceName}}RestateDelayedClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;
private final Duration delay;

{{serviceName}}RestateDelayedClient(RestateContext ctx, Duration delay) {
{{serviceName}}RestateDelayedClient(UnkeyedContext ctx, Duration delay) {
this.ctx = ctx;
this.delay = delay;
}
Expand All @@ -100,7 +94,7 @@ public class {{className}} {
@java.lang.Deprecated
{{/deprecated}}
{{{apidoc}}}
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}(RestateContext context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException {
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}({{contextType}} context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException {
throw new dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED);
}

Expand All @@ -120,34 +114,34 @@ public class {{className}} {
private static final class HandlerAdapter<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp> {

private final java.util.function.BiFunction<RestateContext, Req, Resp> handler;
private final java.util.function.BiFunction<KeyedContext, Req, Resp> handler;

private HandlerAdapter(java.util.function.BiFunction<RestateContext, Req, Resp> handler) {
private HandlerAdapter(java.util.function.BiFunction<KeyedContext, Req, Resp> handler) {
this.handler = handler;
}

@Override
public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
responseObserver.onNext(handler.apply(RestateContext.fromSyscalls(Syscalls.current()), request));
responseObserver.onNext(handler.apply(KeyedContext.fromSyscalls(Syscalls.current()), request));
responseObserver.onCompleted();
}

private static <Req, Resp> HandlerAdapter<Req, Resp> of(java.util.function.BiFunction<RestateContext, Req, Resp> handler) {
private static <Req, Resp> HandlerAdapter<Req, Resp> of(java.util.function.BiFunction<KeyedContext, Req, Resp> handler) {
return new HandlerAdapter<>(handler);
}

private static <Resp> HandlerAdapter<com.google.protobuf.Empty, Resp> of(java.util.function.Function<RestateContext, Resp> handler) {
private static <Resp> HandlerAdapter<com.google.protobuf.Empty, Resp> of(java.util.function.Function<KeyedContext, Resp> handler) {
return new HandlerAdapter<>((ctx, e) -> handler.apply(ctx));
}

private static <Req> HandlerAdapter<Req, com.google.protobuf.Empty> of(java.util.function.BiConsumer<RestateContext, Req> handler) {
private static <Req> HandlerAdapter<Req, com.google.protobuf.Empty> of(java.util.function.BiConsumer<KeyedContext, Req> handler) {
return new HandlerAdapter<>((ctx, req) -> {
handler.accept(ctx, req);
return com.google.protobuf.Empty.getDefaultInstance();
});
}

private static HandlerAdapter<com.google.protobuf.Empty, com.google.protobuf.Empty> of(java.util.function.Consumer<RestateContext> handler) {
private static HandlerAdapter<com.google.protobuf.Empty, com.google.protobuf.Empty> of(java.util.function.Consumer<KeyedContext> handler) {
return new HandlerAdapter<>((ctx, req) -> {
handler.accept(ctx);
return com.google.protobuf.Empty.getDefaultInstance();
Expand Down
24 changes: 12 additions & 12 deletions protoc-gen-restate/src/main/resources/ktStub.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
package {{packageName}};
{{/packageName}}

import dev.restate.sdk.kotlin.RestateContext;
import dev.restate.sdk.kotlin.UnkeyedContext;
import dev.restate.sdk.kotlin.KeyedContext;
import dev.restate.sdk.kotlin.Awaitable;
import dev.restate.sdk.kotlin.RestateKtService;
import dev.restate.sdk.common.syscalls.Syscalls;
import io.grpc.kotlin.ClientCalls.unaryRpc
import io.grpc.kotlin.ServerCalls.unaryServerMethodDefinition
import {{packageName}}.{{serviceName}}Grpc.getServiceDescriptor;
import dev.restate.sdk.kotlin.restateContextFromSyscalls;
import kotlin.time.Duration

{{#deprecated}}
Expand All @@ -18,14 +18,14 @@ import kotlin.time.Duration
public object {{className}} {

/**
* Create a new client from the given [RestateContext].
* Create a new client from the given [UnkeyedContext].
*/
fun newClient(ctx: RestateContext): {{serviceName}}RestateKtClient {
fun newClient(ctx: UnkeyedContext): {{serviceName}}RestateKtClient {
return {{serviceName}}RestateKtClient(ctx);
}

{{{javadoc}}}
public class {{serviceName}}RestateKtClient(private val ctx: RestateContext) {
public class {{serviceName}}RestateKtClient(private val ctx: UnkeyedContext) {
// Create a variant of this client to execute oneWay calls.
public fun oneWay(): {{serviceName}}RestateKtOneWayClient {
return {{serviceName}}RestateKtOneWayClient(ctx);
Expand All @@ -46,7 +46,7 @@ public object {{className}} {
{{/methods}}
}

public class {{serviceName}}RestateKtOneWayClient(private val ctx: RestateContext) {
public class {{serviceName}}RestateKtOneWayClient(private val ctx: UnkeyedContext) {
{{#methods}}
{{#deprecated}}@Deprecated{{/deprecated}}
{{{javadoc}}}
Expand All @@ -57,7 +57,7 @@ public object {{className}} {
{{/methods}}
}

public class {{serviceName}}RestateKtDelayedClient(private val ctx: RestateContext, private val delay: Duration) {
public class {{serviceName}}RestateKtDelayedClient(private val ctx: UnkeyedContext, private val delay: Duration) {
{{#methods}}
{{#deprecated}}@Deprecated{{/deprecated}}
{{{javadoc}}}
Expand All @@ -78,7 +78,7 @@ public object {{className}} {
@Deprecated
{{/deprecated}}
{{{javadoc}}}
public open suspend fun {{methodName}}(context: RestateContext{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} {
public open suspend fun {{methodName}}(context: {{contextType}}{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} {
throw dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED);
}

Expand All @@ -91,18 +91,18 @@ public object {{className}} {
descriptor = {{packageName}}.{{serviceName}}Grpc.{{methodDescriptorGetter}}(),
implementation = {
{{#isInputEmpty}}{{#isOutputEmpty}}
{{methodName}}(restateContextFromSyscalls(Syscalls.current()))
{{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()))
return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance()
{{/isOutputEmpty}}{{/isInputEmpty}}
{{#isInputEmpty}}{{^isOutputEmpty}}
return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current()))
return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()))
{{/isOutputEmpty}}{{/isInputEmpty}}
{{^isInputEmpty}}{{#isOutputEmpty}}
{{methodName}}(restateContextFromSyscalls(Syscalls.current()), it)
{{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it)
return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance()
{{/isOutputEmpty}}{{/isInputEmpty}}
{{^isInputEmpty}}{{^isOutputEmpty}}
return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current()), it)
return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it)
{{/isOutputEmpty}}{{/isInputEmpty}}
}
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import kotlin.time.Duration
import kotlin.time.toJavaDuration
import kotlinx.coroutines.*

internal class RestateContextImpl internal constructor(private val syscalls: Syscalls) :
RestateContext {
internal class ContextImpl internal constructor(private val syscalls: Syscalls) : KeyedContext {
override suspend fun <T : Any> get(key: StateKey<T>): T? {
val deferred: Deferred<ByteString> =
suspendCancellableCoroutine { cont: CancellableContinuation<Deferred<ByteString>> ->
Expand Down
Loading
Loading