Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Shared handlers annotation for virtual objects #288

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/src/main/java/my/restate/sdk/examples/Counter.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
package my.restate.sdk.examples;

import dev.restate.sdk.ObjectContext;
import dev.restate.sdk.SharedObjectContext;
import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.Shared;
import dev.restate.sdk.annotation.VirtualObject;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
Expand All @@ -36,8 +38,9 @@ public void add(ObjectContext ctx, Long request) {
ctx.set(TOTAL, newValue);
}

@Shared
@Handler
public Long get(ObjectContext ctx) {
public Long get(SharedObjectContext ctx) {
return ctx.get(TOTAL).orElse(0L);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
package my.restate.sdk.examples

import dev.restate.sdk.annotation.Handler
import dev.restate.sdk.annotation.Shared
import dev.restate.sdk.annotation.VirtualObject
import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder
import dev.restate.sdk.kotlin.KtStateKey
import dev.restate.sdk.kotlin.ObjectContext
import dev.restate.sdk.kotlin.SharedObjectContext
import kotlinx.serialization.Serializable
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.Logger
Expand Down Expand Up @@ -40,7 +42,8 @@ class CounterKt {
}

@Handler
suspend fun get(ctx: ObjectContext): Long {
@Shared
suspend fun get(ctx: SharedObjectContext): Long {
return ctx.get(TOTAL) ?: 0L
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import dev.restate.sdk.Context;
import dev.restate.sdk.ObjectContext;
import dev.restate.sdk.SharedObjectContext;
import dev.restate.sdk.annotation.Exclusive;
import dev.restate.sdk.annotation.Shared;
import dev.restate.sdk.annotation.Workflow;
Expand Down Expand Up @@ -233,6 +234,8 @@ private void validateMethodSignature(
case SHARED:
if (serviceType == ServiceType.WORKFLOW) {
validateFirstParameterType(WorkflowSharedContext.class, element);
} else if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(SharedObjectContext.class, element);
} else {
messager.printMessage(
Diagnostic.Kind.ERROR,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.Bind
public {{generatedClassSimpleName}}({{originalClassFqcn}} bindableService, dev.restate.sdk.Service.Options options) {
this.service = dev.restate.sdk.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME)
{{#handlers}}
.with(
.{{#if isShared}}withShared{{else if isExclusive}}withExclusive{{else}}with{{/if}}(
dev.restate.sdk.Service.HandlerSignature.of("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}}),
(ctx, req) -> {
{{#if outputEmpty}}
Expand Down
14 changes: 11 additions & 3 deletions sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import static dev.restate.sdk.core.TestDefinitions.testInvocation;

import com.google.protobuf.ByteString;
import dev.restate.sdk.annotation.Exclusive;
import dev.restate.sdk.annotation.Handler;
import dev.restate.sdk.annotation.*;
import dev.restate.sdk.annotation.Service;
import dev.restate.sdk.annotation.VirtualObject;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.Target;
import dev.restate.sdk.core.ProtoUtils;
Expand All @@ -39,6 +37,12 @@ static class ObjectGreeter {
String greet(ObjectContext context, String request) {
return request;
}

@Handler
@Shared
String sharedGreet(SharedObjectContext context, String request) {
return request;
}
}

@VirtualObject
Expand Down Expand Up @@ -113,6 +117,10 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation(ObjectGreeter::new, "sharedGreet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation(ObjectGreeterImplementedFromInterface::new, "greet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import dev.restate.sdk.gen.model.PayloadType
import dev.restate.sdk.gen.model.Service
import dev.restate.sdk.kotlin.Context
import dev.restate.sdk.kotlin.ObjectContext
import dev.restate.sdk.kotlin.SharedObjectContext
import java.util.regex.Pattern
import kotlin.reflect.KClass

Expand Down Expand Up @@ -128,7 +129,7 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB
}

val isAnnotatedWithShared =
function.isAnnotationPresent(dev.restate.sdk.annotation.Service::class)
function.isAnnotationPresent(dev.restate.sdk.annotation.Shared::class)
val isAnnotatedWithExclusive =
function.isAnnotationPresent(dev.restate.sdk.annotation.Exclusive::class)

Expand Down Expand Up @@ -190,8 +191,13 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB
}
when (handlerType) {
HandlerType.SHARED ->
logger.error(
"The annotation @Shared is not supported by the service type $serviceType", function)
if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(SharedObjectContext::class, function)
} else {
logger.error(
"The annotation @Shared is not supported by the service type $serviceType",
function)
}
HandlerType.EXCLUSIVE ->
if (serviceType == ServiceType.VIRTUAL_OBJECT) {
validateFirstParameterType(ObjectContext::class, function)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class {{generatedClassSimpleName}}(

val service: dev.restate.sdk.kotlin.Service = dev.restate.sdk.kotlin.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME, options) {
{{#handlers}}
handler(dev.restate.sdk.kotlin.Service.HandlerSignature("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}})) { ctx, req ->
{{#if isShared}}sharedHandler{{else if isExclusive}}exclusiveHandler{{else}}handler{{/if}}(dev.restate.sdk.kotlin.Service.HandlerSignature("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}})) { ctx, req ->
{{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}
}
{{/handlers}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
package dev.restate.sdk.kotlin

import com.google.protobuf.ByteString
import dev.restate.sdk.annotation.Exclusive
import dev.restate.sdk.annotation.Handler
import dev.restate.sdk.annotation.*
import dev.restate.sdk.annotation.Service
import dev.restate.sdk.annotation.VirtualObject
import dev.restate.sdk.common.CoreSerdes
import dev.restate.sdk.common.Target
import dev.restate.sdk.core.ProtoUtils.*
Expand All @@ -36,6 +34,12 @@ class CodegenTest : TestDefinitions.TestSuite {
suspend fun greet(context: ObjectContext, request: String): String {
return request
}

@Handler
@Shared
suspend fun sharedGreet(context: SharedObjectContext, request: String): String {
return request
}
}

@VirtualObject
Expand Down Expand Up @@ -104,6 +108,10 @@ class CodegenTest : TestDefinitions.TestSuite {
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation({ ObjectGreeter() }, "sharedGreet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
.expectingOutput(outputMessage("Francesco"), END_MESSAGE),
testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet")
.withInput(startMessage(1, "slinkydeveloper"), inputMessage("Francesco"))
.onlyUnbuffered()
Expand Down
30 changes: 21 additions & 9 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
package dev.restate.sdk.kotlin

import com.google.protobuf.ByteString
import dev.restate.sdk.common.BindableService
import dev.restate.sdk.common.Serde
import dev.restate.sdk.common.ServiceType
import dev.restate.sdk.common.TerminalException
import dev.restate.sdk.common.*
import dev.restate.sdk.common.syscalls.*
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -65,18 +62,31 @@ private constructor(
class VirtualObjectBuilder internal constructor(private val name: String) {
private val handlers: MutableMap<String, Handler<*, *, ObjectContext>> = mutableMapOf()

fun <REQ, RES> handler(
fun <REQ, RES> sharedHandler(
sig: HandlerSignature<REQ, RES>,
runner: suspend (ObjectContext, REQ) -> RES
): VirtualObjectBuilder {
handlers[sig.name] = Handler(sig, runner)
handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner)
return this
}

inline fun <reified REQ, reified RES> handler(
inline fun <reified REQ, reified RES> sharedHandler(
name: String,
noinline runner: suspend (ObjectContext, REQ) -> RES
) = this.handler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)
) = this.sharedHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)

fun <REQ, RES> exclusiveHandler(
sig: HandlerSignature<REQ, RES>,
runner: suspend (ObjectContext, REQ) -> RES
): VirtualObjectBuilder {
handlers[sig.name] = Handler(sig, HandlerType.EXCLUSIVE, runner)
return this
}

inline fun <reified REQ, reified RES> exclusiveHandler(
name: String,
noinline runner: suspend (ObjectContext, REQ) -> RES
) = this.exclusiveHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner)

fun build(options: Options) = Service(this.name, true, this.handlers, options)
}
Expand All @@ -88,7 +98,7 @@ private constructor(
sig: HandlerSignature<REQ, RES>,
runner: suspend (Context, REQ) -> RES
): ServiceBuilder {
handlers[sig.name] = Handler(sig, runner)
handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner)
return this
}

Expand All @@ -102,6 +112,7 @@ private constructor(

class Handler<REQ, RES, CTX : Context>(
private val handlerSignature: HandlerSignature<REQ, RES>,
private val handlerType: HandlerType,
private val runner: suspend (CTX, REQ) -> RES,
) : InvocationHandler<Options> {

Expand All @@ -112,6 +123,7 @@ private constructor(
fun toHandlerDefinition() =
HandlerDefinition(
handlerSignature.name,
handlerType,
handlerSignature.requestSerde.schema(),
handlerSignature.responseSerde.schema(),
this)
Expand Down
9 changes: 8 additions & 1 deletion sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ suspend inline fun <reified T : Any> Context.awakeable(): Awakeable<T> {
* This interface extends [Context] adding access to the virtual object instance key-value state
* storage.
*/
sealed interface ObjectContext : Context {
sealed interface SharedObjectContext : Context {

/** @return the key of this object */
fun key(): String
Expand All @@ -267,6 +267,13 @@ sealed interface ObjectContext : Context {
* @return the immutable collection of known state keys.
*/
suspend fun stateKeys(): Collection<String>
}

/**
* This interface extends [Context] adding access to the virtual object instance key-value state
* storage.
*/
sealed interface ObjectContext : SharedObjectContext {

/**
* Sets the given value under the given key, serializing the value using the [StateKey.serde].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class KotlinCoroutinesTests : TestRunner() {
): TestInvocationBuilder {
return TestDefinitions.testInvocation(
Service.virtualObject(name, Service.Options(Dispatchers.Unconfined)) {
handler("run", runner)
exclusiveHandler("run", runner)
},
"run")
}
Expand Down
28 changes: 1 addition & 27 deletions sdk-api/src/main/java/dev/restate/sdk/ObjectContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
package dev.restate.sdk;

import dev.restate.sdk.common.*;
import java.util.Collection;
import java.util.Optional;
import org.jspecify.annotations.NonNull;

/**
Expand All @@ -22,31 +20,7 @@
*
* @see Context
*/
public interface ObjectContext extends Context {

/**
* @return the key of this object
*/
String key();

/**
* Gets the state stored under key, deserializing the raw value using the {@link Serde} in the
* {@link StateKey}.
*
* @param key identifying the state to get and its type.
* @return an {@link Optional} containing the stored state deserialized or an empty {@link
* Optional} if not set yet.
* @throws RuntimeException when the state cannot be deserialized.
*/
<T> Optional<T> get(StateKey<T> key);

/**
* Gets all the known state keys for this virtual object instance.
*
* @return the immutable collection of known state keys.
*/
Collection<String> stateKeys();

public interface ObjectContext extends SharedObjectContext {
/**
* Clears the state stored under key.
*
Expand Down
16 changes: 13 additions & 3 deletions sdk-api/src/main/java/dev/restate/sdk/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,15 @@ public static class VirtualObjectBuilder extends AbstractServiceBuilder {
super(name);
}

public <REQ, RES> VirtualObjectBuilder with(
public <REQ, RES> VirtualObjectBuilder withShared(
HandlerSignature<REQ, RES> sig, BiFunction<SharedObjectContext, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner));
return this;
}

public <REQ, RES> VirtualObjectBuilder withExclusive(
HandlerSignature<REQ, RES> sig, BiFunction<ObjectContext, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, runner));
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.EXCLUSIVE, runner));
return this;
}

Expand All @@ -90,7 +96,7 @@ public static class ServiceBuilder extends AbstractServiceBuilder {

public <REQ, RES> ServiceBuilder with(
HandlerSignature<REQ, RES> sig, BiFunction<Context, REQ, RES> runner) {
this.handlers.put(sig.getName(), new Handler<>(sig, runner));
this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner));
return this;
}

Expand All @@ -102,14 +108,17 @@ public Service build(Service.Options options) {
@SuppressWarnings("unchecked")
public static class Handler<REQ, RES> implements InvocationHandler<Service.Options> {
private final HandlerSignature<REQ, RES> handlerSignature;
private final HandlerType handlerType;
private final BiFunction<Context, REQ, RES> runner;

private static final Logger LOG = LogManager.getLogger(Handler.class);

public Handler(
HandlerSignature<REQ, RES> handlerSignature,
HandlerType handlerType,
BiFunction<? extends Context, REQ, RES> runner) {
this.handlerSignature = handlerSignature;
this.handlerType = handlerType;
this.runner = (BiFunction<Context, REQ, RES>) runner;
}

Expand All @@ -124,6 +133,7 @@ public BiFunction<Context, REQ, RES> getRunner() {
public HandlerDefinition<Service.Options> toHandlerDefinition() {
return new HandlerDefinition<>(
this.handlerSignature.name,
this.handlerType,
this.handlerSignature.requestSerde.schema(),
this.handlerSignature.responseSerde.schema(),
this);
Expand Down
Loading
Loading