From 8465bb3007a3a95f3aefe73c92a211ee54bca484 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Fri, 17 Jan 2025 18:45:05 +0100 Subject: [PATCH] feat: add confidential chat completions streaming endpoint and response schema - Introduced a new POST endpoint at /v1/confidential/chat/completions#stream for streaming chat completions. - Added the ConfidentialComputeStreamResponse schema to represent the streaming response structure. - Updated OpenAPI documentation to include the new endpoint and its associated request/response details. - Modified existing handlers to accommodate the new streaming functionality while maintaining backward compatibility. --- atoma-proxy/docs/openapi.yml | 37 +++++++++++++++++ atoma-proxy/src/server/components/openapi.rs | 4 ++ .../src/server/handlers/chat_completions.rs | 40 +++++++++++++++++-- atoma-proxy/src/server/types.rs | 7 ++++ 4 files changed, 85 insertions(+), 3 deletions(-) diff --git a/atoma-proxy/docs/openapi.yml b/atoma-proxy/docs/openapi.yml index ffd4875..2bb65b3 100644 --- a/atoma-proxy/docs/openapi.yml +++ b/atoma-proxy/docs/openapi.yml @@ -136,6 +136,32 @@ paths: description: Internal server error security: - bearerAuth: [] + /v1/confidential/chat/completions#stream: + post: + tags: + - Confidential Chat + operationId: confidential_chat_completions_create_stream + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ConfidentialComputeRequest' + required: true + responses: + '200': + description: Chat completions + content: + text/event-stream: + schema: + $ref: '#/components/schemas/ConfidentialComputeStreamResponse' + '400': + description: Bad request + '401': + description: Unauthorized + '500': + description: Internal server error + security: + - bearerAuth: [] /v1/confidential/embeddings: post: tags: @@ -743,6 +769,15 @@ components: - type: 'null' - $ref: '#/components/schemas/Usage' description: Usage statistics for the request + ConfidentialComputeStreamResponse: + type: object + description: Represents a response from a confidential compute request + required: + - data + properties: + data: + $ref: '#/components/schemas/ConfidentialComputeResponse' + description: The stream of chat completion chunks. CreateChatCompletionRequest: allOf: - $ref: '#/components/schemas/ChatCompletionRequest' @@ -1128,6 +1163,8 @@ x-speakeasy-name-override: methodNameOverride: create_stream - operationId: confidential_chat_completions_create methodNameOverride: create +- operationId: confidential_chat_completions_create_stream + methodNameOverride: create_stream - operationId: embeddings_create methodNameOverride: create - operationId: confidential_embeddings_create diff --git a/atoma-proxy/src/server/components/openapi.rs b/atoma-proxy/src/server/components/openapi.rs index 1f4ea22..9298c1a 100644 --- a/atoma-proxy/src/server/components/openapi.rs +++ b/atoma-proxy/src/server/components/openapi.rs @@ -85,6 +85,10 @@ pub fn openapi_routes() -> Router { "operationId": "confidential_chat_completions_create", "methodNameOverride": "create" }, + { + "operationId": "confidential_chat_completions_create_stream", + "methodNameOverride": "create_stream" + }, { "operationId": "embeddings_create", "methodNameOverride": "create" diff --git a/atoma-proxy/src/server/handlers/chat_completions.rs b/atoma-proxy/src/server/handlers/chat_completions.rs index 5f4026c..f5ae0f9 100644 --- a/atoma-proxy/src/server/handlers/chat_completions.rs +++ b/atoma-proxy/src/server/handlers/chat_completions.rs @@ -1,6 +1,6 @@ use std::time::{Duration, Instant}; -use crate::server::types::ConfidentialComputeResponse; +use crate::server::types::{ConfidentialComputeResponse, ConfidentialComputeStreamResponse}; use crate::server::{ error::AtomaProxyError, http_server::ProxyState, middleware::RequestMetadataExtension, streamer::Streamer, types::ConfidentialComputeRequest, @@ -240,7 +240,7 @@ pub async fn chat_completions_create_stream( // This endpoint exists only for OpenAPI documentation // Actual streaming is handled by chat_completions_create Err(AtomaProxyError::NotImplemented { - message: "Streaming is not implemented".to_string(), + message: "This is a mock endpoint for OpenAPI documentation".to_string(), endpoint: CHAT_COMPLETIONS_PATH.to_string(), }) } @@ -268,7 +268,10 @@ pub async fn chat_completions_create_stream( /// * `ChatCompletionChunkDelta` - Incremental updates in streaming responses #[derive(OpenApi)] #[openapi( - paths(confidential_chat_completions_create), + paths( + confidential_chat_completions_create, + confidential_chat_completions_create_stream + ), components(schemas(ConfidentialComputeRequest)) )] pub(crate) struct ConfidentialChatCompletionsOpenApi; @@ -350,6 +353,37 @@ pub async fn confidential_chat_completions_create( } } +#[utoipa::path( + post, + path = "#stream", + security( + ("bearerAuth" = []) + ), + request_body = ConfidentialComputeRequest, + responses( + (status = OK, description = "Chat completions", content( + (ConfidentialComputeStreamResponse = "text/event-stream") + )), + (status = BAD_REQUEST, description = "Bad request"), + (status = UNAUTHORIZED, description = "Unauthorized"), + (status = INTERNAL_SERVER_ERROR, description = "Internal server error") + ) +)] +#[allow(dead_code)] +pub async fn confidential_chat_completions_create_stream( + Extension(_metadata): Extension, + State(_state): State, + _headers: HeaderMap, + Json(_payload): Json, +) -> Result> { + // This endpoint exists only for OpenAPI documentation + // Actual streaming is handled by chat_completions_create + Err(AtomaProxyError::NotImplemented { + message: "This is a mock endpoint for OpenAPI documentation".to_string(), + endpoint: CHAT_COMPLETIONS_PATH.to_string(), + }) +} + /// Handles non-streaming chat completion requests by processing them through the inference service. /// /// This function performs several key operations: diff --git a/atoma-proxy/src/server/types.rs b/atoma-proxy/src/server/types.rs index 8bd7e04..9a057fa 100644 --- a/atoma-proxy/src/server/types.rs +++ b/atoma-proxy/src/server/types.rs @@ -59,6 +59,13 @@ pub struct ConfidentialComputeResponse { pub usage: Option, } +/// Represents a response from a confidential compute request +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct ConfidentialComputeStreamResponse { + /// The stream of chat completion chunks. + pub data: ConfidentialComputeResponse, +} + /// Represents usage statistics for a confidential compute request #[derive(Debug, Deserialize, Serialize, ToSchema)] pub struct Usage {