Skip to content

Commit

Permalink
Adds grpc plugin (#2388)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 16, 2024
1 parent 5c46335 commit 625b5bd
Show file tree
Hide file tree
Showing 19 changed files with 828 additions and 7 deletions.
4 changes: 2 additions & 2 deletions buildSrc/src/main/kotlin/ai/djl/javaFormatter.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ tasks {
val formatter = Main(PrintWriter(System.out, true), PrintWriter(System.err, true), System.`in`)
for (item in project.sourceSets)
for (file in item.allSource) {
if (!file.name.endsWith(".java") || "generated-src" in file.absolutePath)
if (!file.name.endsWith(".java") || "generated" in file.absolutePath)
continue
if (formatter.format("-a", "-i", file.absolutePath) != 0)
throw GradleException("Format java failed: " + file.absolutePath)
Expand All @@ -28,7 +28,7 @@ tasks {
val formatter = Main(PrintWriter(System.out, true), PrintWriter(System.err, true), System.`in`)
for (item in project.sourceSets)
for (file in item.allSource) {
if (!file.name.endsWith(".java") || "generated-src" in file.absolutePath)
if (!file.name.endsWith(".java") || "generated" in file.absolutePath)
continue
if (formatter.format("-a", "-n", "--set-exit-if-changed", file.absolutePath) != 0)
throw GradleException(
Expand Down
3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ prometheus = "1.3.1"
awssdk = "2.26.19"
httpcomponents = "4.5.14"
jsonpath = "2.9.0"
grpc = "1.66.0"
annotationsApi = "6.0.53"
protoc = "3.25.3"

testng = "7.10.2"
junit = "4.13.2"
Expand Down
10 changes: 10 additions & 0 deletions jacoco/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies {
jacocoAggregation(project(":benchmark"))
jacocoAggregation(project(":engines:python"))
jacocoAggregation(project(":plugins:cache"))
jacocoAggregation(project(":plugins:grpc"))
jacocoAggregation(project(":plugins:kserve"))
// jacocoAggregation(project(":plugins:management-console"))
jacocoAggregation(project(":plugins:plugin-management-plugin"))
Expand All @@ -33,6 +34,15 @@ tasks {
}
}

// Re-scope the aggregated coverage report's class directories so that everything
// under ai/djl/serving/grpc/proto (presumably the protoc-generated gRPC classes;
// confirm against plugins/grpc) is left out of the report.
val testCodeCoverageReport by getting(JacocoReport::class) {
// Wrap each existing class directory in a file tree that carries the exclude pattern.
classDirectories.setFrom(files(classDirectories.files.map {
fileTree(it) {
exclude(
"ai/djl/serving/grpc/proto/**"
)
}
}))
}
check {
dependsOn("testCodeCoverageReport")
}
Expand Down
64 changes: 64 additions & 0 deletions plugins/grpc/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Build plugins for the gRPC plugin module: the project's shared Java convention
// plugin plus the protobuf Gradle plugin (pinned to 0.9.4).
plugins {
ai.djl.javaProject
id("com.google.protobuf") version "0.9.4"
}

// Custom configuration holding the dependencies that the jar task subtracts from
// the runtime classpath before bundling (see `tasks.jar` in this file).
val exclusion by configurations.registering

@Suppress("UnstableApiUsage")
dependencies {
api(platform("ai.djl:bom:${version}"))
implementation(project(":serving"))
// gRPC artifacts all share the single `grpc` version from the version catalog.
implementation("io.grpc:grpc-netty-shaded:${libs.versions.grpc.get()}")
implementation("io.grpc:grpc-protobuf:${libs.versions.grpc.get()}")
implementation("io.grpc:grpc-stub:${libs.versions.grpc.get()}")
implementation("io.grpc:protoc-gen-grpc-java:${libs.versions.grpc.get()}")
// necessary for Java 9+
compileOnly("org.apache.tomcat:annotations-api:${libs.versions.annotationsApi.get()}")

testImplementation(libs.commons.cli)
testImplementation(libs.testng) {
exclude(group = "junit", module = "junit")
}

// Entries placed in `exclusion` are removed from the bundled jar contents:
// the serving project itself and gson.
exclusion(project(":serving"))
exclusion("com.google.code.gson:gson")
}

// Protobuf code generation: selects the protoc compiler artifact, registers the
// gRPC Java protoc plugin, and enables that plugin on every generateProto task.
protobuf {
protoc {
// Compiler version comes from the `protoc` entry of the version catalog.
artifact = "com.google.protobuf:protoc:${libs.versions.protoc.get()}"
}
plugins {
// Declares a protoc plugin named "grpc" backed by protoc-gen-grpc-java.
create("grpc") {
artifact = "io.grpc:protoc-gen-grpc-java:${libs.versions.grpc.get()}"
}
}
generateProtoTasks {
// Apply the "grpc" plugin declared above to all proto generation tasks.
all().forEach {
it.plugins {
create("grpc")
}
}
}
}

// Task wiring for the gRPC plugin module.
tasks {
// Run proto generation before resources are processed.
processResources {
dependsOn(generateProto)
}

// Bundle the runtime classpath into the jar, minus everything in the `exclusion`
// configuration. Directories are added as-is; archives are unpacked via zipTree.
jar {
includeEmptyDirs = false
// Entries with the same path from different dependencies are all kept.
duplicatesStrategy = DuplicatesStrategy.INCLUDE
from((configurations.runtimeClasspath.get() - exclusion.get()).map {
if (it.isDirectory()) it else zipTree(it)
})
}

// Run proto generation before the Java format verification task.
verifyJava {
dependsOn(generateProto)
}
// Skip static analysis for the ai.djl.serving.grpc.proto package
// (presumably the protoc-generated sources; confirm).
checkstyleMain { exclude("ai/djl/serving/grpc/proto/*") }
pmdMain { exclude("ai/djl/serving/grpc/proto/*") }
}
119 changes: 119 additions & 0 deletions plugins/grpc/src/main/java/ai/djl/serving/grpc/GrpcClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.grpc;

import ai.djl.serving.grpc.proto.InferenceGrpc;
import ai.djl.serving.grpc.proto.InferenceRequest;
import ai.djl.serving.grpc.proto.InferenceResponse;
import ai.djl.serving.grpc.proto.PingResponse;

import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;

import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/** The gRPC client that connects to the inference server. */
public class GrpcClient implements AutoCloseable {

    private final ManagedChannel channel;
    private final InferenceGrpc.InferenceBlockingStub stub;

    /**
     * Constructs a client that accesses the inference service using the existing channel.
     *
     * @param channel the managed channel
     */
    public GrpcClient(ManagedChannel channel) {
        this.channel = channel;
        stub = InferenceGrpc.newBlockingStub(channel);
    }

    /**
     * Constructs a new instance with target address.
     *
     * <p>The channel is created with insecure (plaintext) channel credentials.
     *
     * @param target the target address
     * @return a new instance
     */
    public static GrpcClient newInstance(String target) {
        ManagedChannel channel =
                Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()).build();
        return new GrpcClient(channel);
    }

    /**
     * Shuts down the underlying channel immediately and waits up to 5 seconds for termination.
     *
     * <p>If the calling thread is interrupted while waiting, the wait is abandoned and the
     * thread's interrupt status is restored.
     */
    @Override
    public void close() {
        try {
            channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException ignored) {
            // close() cannot propagate the interruption, so restore the flag for the caller
            // instead of silently dropping it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Sends the {@code Ping} command to the server.
     *
     * @return the ping response
     */
    public PingResponse ping() {
        Empty request = Empty.getDefaultInstance();
        return stub.ping(request);
    }

    /**
     * Sends an inference request to the server without a model version or headers.
     *
     * @param modelName the model name, may be {@code null} to leave it unset
     * @param data the inference payload, sent as UTF-8 bytes
     * @return the inference responses
     */
    public Iterator<InferenceResponse> inference(String modelName, String data) {
        return inference(modelName, null, Collections.emptyMap(), data);
    }

    /**
     * Sends an inference request to the server.
     *
     * @param modelName the model name, may be {@code null} to leave it unset
     * @param version the model version, may be {@code null} to leave it unset
     * @param headers the input headers, values are sent as UTF-8 bytes; {@code null} is treated
     *     as no headers
     * @param data the inference payload, sent as UTF-8 bytes
     * @return the inference responses
     */
    public Iterator<InferenceResponse> inference(
            String modelName, String version, Map<String, String> headers, String data) {
        InferenceRequest.Builder builder = InferenceRequest.newBuilder();
        if (modelName != null) {
            builder.setModelName(modelName);
        }
        if (version != null) {
            builder.setModelVersion(version);
        }
        if (headers != null) {
            for (Map.Entry<String, String> entry : headers.entrySet()) {
                ByteString value = ByteString.copyFrom(entry.getValue(), StandardCharsets.UTF_8);
                builder.putHeaders(entry.getKey(), value);
            }
        }

        ByteString input = ByteString.copyFrom(data, StandardCharsets.UTF_8);
        InferenceRequest req = builder.setInput(input).build();
        return stub.predict(req);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.grpc;

import ai.djl.serving.http.Session;

import io.grpc.ForwardingServerCall;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ServerInterceptor} that writes one entry to the {@code ACCESS_LOG} logger when a gRPC
 * call is closed.
 */
class GrpcInterceptor implements ServerInterceptor {

    private static final Logger logger = LoggerFactory.getLogger("ACCESS_LOG");

    /** {@inheritDoc} */
    @Override
    public <I, O> ServerCall.Listener<I> interceptCall(
            ServerCall<I, O> call, Metadata headers, ServerCallHandler<I, O> next) {
        // Attributes.get() returns null when the transport did not set the remote address;
        // fall back to a placeholder rather than failing the call with an NPE.
        Object remoteAddr = call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
        String ip = remoteAddr == null ? "unknown" : remoteAddr.toString();
        String serviceName = call.getMethodDescriptor().getFullMethodName();
        Session session = new Session(ip, serviceName);

        return next.startCall(
                new ForwardingServerCall.SimpleForwardingServerCall<>(call) {

                    /** {@inheritDoc} */
                    @Override
                    public void close(final Status status, final Metadata trailers) {
                        // Record the final status code and log before delegating the close.
                        session.setCode(status.getCode().value());
                        logger.info(session.toString());
                        super.close(status, trailers);
                    }
                },
                headers);
    }
}
74 changes: 74 additions & 0 deletions plugins/grpc/src/main/java/ai/djl/serving/grpc/GrpcServerImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.grpc;

import ai.djl.serving.GrpcServer;
import ai.djl.serving.util.ConfigManager;

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;

/**
 * {@link GrpcServer} implementation.
 *
 * <p>The server is configured from the following {@link ConfigManager} properties:
 *
 * <ul>
 *   <li>{@code grpc_address}: bind address, defaults to {@code 127.0.0.1}
 *   <li>{@code grpc_port}: bind port, defaults to {@code 8082}
 *   <li>{@code grpc_max_connection_age}: max connection age in milliseconds, defaults to {@link
 *       Integer#MAX_VALUE}
 *   <li>{@code grpc_max_connection_grace}: max connection age grace period in milliseconds,
 *       defaults to {@link Integer#MAX_VALUE}
 * </ul>
 */
public class GrpcServerImpl extends GrpcServer {

    private static final Logger logger = LoggerFactory.getLogger(GrpcServerImpl.class);

    private Server server;

    /** {@inheritDoc} */
    @Override
    public void start() throws IOException {
        ConfigManager configManager = ConfigManager.getInstance();
        String ip = configManager.getProperty("grpc_address", "127.0.0.1");
        int port = configManager.getIntProperty("grpc_port", 8082);
        InetSocketAddress address = new InetSocketAddress(ip, port);
        long maxConnectionAge =
                configManager.getIntProperty("grpc_max_connection_age", Integer.MAX_VALUE);
        long maxConnectionGrace =
                configManager.getIntProperty("grpc_max_connection_grace", Integer.MAX_VALUE);

        // Every call to the inference service passes through the access-log interceptor.
        ServerBuilder<?> s =
                NettyServerBuilder.forAddress(address)
                        .maxConnectionAge(maxConnectionAge, TimeUnit.MILLISECONDS)
                        .maxConnectionAgeGrace(maxConnectionGrace, TimeUnit.MILLISECONDS)
                        .maxInboundMessageSize(configManager.getMaxRequestSize())
                        .addService(
                                ServerInterceptors.intercept(
                                        new InferenceService(), new GrpcInterceptor()));

        server = s.build();
        server.start();
        logger.info("gRPC bind to port: {}:{}", ip, port);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Initiates a shutdown and waits up to 30 seconds for the server to terminate. Does
     * nothing if the server was never started. If the calling thread is interrupted while
     * waiting, a warning is logged and the thread's interrupt status is restored.
     */
    @Override
    public void stop() {
        if (server != null) {
            try {
                server.shutdown().awaitTermination(30, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                logger.warn("Stop gPRC server failed", e);
                // stop() cannot propagate the interruption; restore the flag so callers
                // further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
        }
    }
}
Loading

0 comments on commit 625b5bd

Please sign in to comment.