Skip to content

Commit

Permalink
Add observability to moonshot chat model
Browse files Browse the repository at this point in the history
  • Loading branch information
mxsl-gr authored and markpollack committed Oct 7, 2024
1 parent ee7d2b2 commit 1606383
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 85 deletions.
6 changes: 6 additions & 0 deletions models/spring-ai-moonshot/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,29 @@
*/
package org.springframework.ai.moonshot;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
Expand All @@ -44,6 +53,7 @@
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletionRequest;
import org.springframework.ai.moonshot.api.MoonshotApi.FunctionTool;
import org.springframework.ai.moonshot.api.MoonshotConstants;
import org.springframework.ai.moonshot.metadata.MoonshotUsage;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand All @@ -66,6 +76,8 @@ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatMo

private static final Logger logger = LoggerFactory.getLogger(MoonshotChatModel.class);

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

/**
* The default options used for the chat completion requests.
*/
Expand All @@ -78,6 +90,16 @@ public class MoonshotChatModel extends AbstractToolCallSupport implements ChatMo

private final RetryTemplate retryTemplate;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
* Initializes a new instance of the MoonshotChatModel.
* @param moonshotApi The Moonshot instance to be used for interacting with the
Expand Down Expand Up @@ -107,7 +129,7 @@ public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options) {
*/
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
this(moonshotApi, options, functionCallbackContext, List.of(), retryTemplate);
this(moonshotApi, options, functionCallbackContext, List.of(), retryTemplate, ObservationRegistry.NOOP);
}

/**
Expand All @@ -118,63 +140,81 @@ public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options,
* @param functionCallbackContext The function callback context.
* @param toolFunctionCallbacks The tool function callbacks.
* @param retryTemplate The retry template.
* @param observationRegistry The ObservationRegistry used for instrumentation.
*/
public MoonshotChatModel(MoonshotApi moonshotApi, MoonshotChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
RetryTemplate retryTemplate) {
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
Assert.notNull(moonshotApi, "MoonshotApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
Assert.isTrue(CollectionUtils.isEmpty(options.getFunctionCallbacks()),
"The default function callbacks must be set via the toolFunctionCallbacks constructor parameter");
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
this.moonshotApi = moonshotApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
this.observationRegistry = observationRegistry;
}

@Override
public ChatResponse call(Prompt prompt) {
ChatCompletionRequest request = createRequest(prompt, false);

ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
.execute(ctx -> this.moonshotApi.chatCompletionEntity(request));
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MoonshotConstants.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

var chatCompletion = completionEntity.getBody();
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
.execute(ctx -> this.moonshotApi.chatCompletionEntity(request));

if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
var chatCompletion = completionEntity.getBody();

List<Choice> choices = chatCompletion.choices();
if (choices == null) {
logger.warn("No choices returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
if (chatCompletion == null) {
logger.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<Choice> choices = chatCompletion.choices();
if (choices == null) {
logger.warn("No choices returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<Generation> generations = choices.stream().map(choice -> {
List<Generation> generations = choices.stream().map(choice -> {
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
Map<String, Object> metadata = Map.of(
"id", chatCompletion.id(),
"role", choice.message().role() != null ? choice.message().role().name() : "",
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
);
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
observationContext.setResponse(chatResponse);

return chatResponse;
});

if (!isProxyToolCalls(prompt, this.defaultOptions)
&& isToolCall(chatResponse, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
&& isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
MoonshotApi.ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

return chatResponse;
return response;
}

@Override
Expand All @@ -184,83 +224,84 @@ public ChatOptions getDefaultOptions() {

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
ChatCompletionRequest request = createRequest(prompt, true);
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
.execute(ctx -> this.moonshotApi.chatCompletionStream(request));
Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
.execute(ctx -> this.moonshotApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
@SuppressWarnings("null")
String id = chatCompletion2.id();
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MoonshotConstants.PROVIDER_NAME)
.requestOptions(buildRequestOptions(request))
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
// the function call handling logic.
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
try {
String id = chatCompletion2.id();

// @formatter:off
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
if (choice.message().role() != null) {
roleMap.putIfAbsent(id, choice.message().role().name());
}

// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
);
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
// @formatter:on

if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}

}));

return chatResponse.flatMap(response -> {
}));

if (!isProxyToolCalls(prompt, this.defaultOptions)
&& isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), "stop"))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
else {
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}
return Flux.just(response);
}
});
}
})
.doOnError(observation::error)
.doFinally(signalType -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

private ChatResponseMetadata from(ChatCompletion result, RateLimit rateLimit) {
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.withId(result.id())
.withUsage(MoonshotUsage.from(result.usage()))
.withModel(result.model())
.withRateLimit(rateLimit)
.withKeyValue("created", result.created())
.build();
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
});
}

private ChatResponseMetadata from(ChatCompletion result) {
Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.withId(result.id())
.withUsage(MoonshotUsage.from(result.usage()))
.withModel(result.model())
.withKeyValue("created", result.created())
.withId(result.id() != null ? result.id() : "")
.withUsage(result.usage() != null ? MoonshotUsage.from(result.usage()) : new EmptyUsage())
.withModel(result.model() != null ? result.model() : "")
.withKeyValue("created", result.created() != null ? result.created() : 0L)
.build();
}

Expand Down Expand Up @@ -374,6 +415,18 @@ else if (message.getMessageType() == MessageType.TOOL) {
return request;
}

private ChatOptions buildRequestOptions(MoonshotApi.ChatCompletionRequest request) {
return ChatOptionsBuilder.builder()
.withModel(request.model())
.withFrequencyPenalty(request.frequencyPenalty())
.withMaxTokens(request.maxTokens())
.withPresencePenalty(request.presencePenalty())
.withStopSequences(request.stop())
.withTemperature(request.temperature())
.withTopP(request.topP())
.build();
}

private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
Expand All @@ -382,4 +435,8 @@ private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
}).toList();
}

public void setObservationConvention(ChatModelObservationConvention observationConvention) {
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package org.springframework.ai.moonshot.api;

import org.springframework.ai.observation.conventions.AiProvider;

/**
* @author Geng Rong
*/
public final class MoonshotConstants {

public static final String DEFAULT_BASE_URL = "https://api.moonshot.cn";

public static final String PROVIDER_NAME = AiProvider.MOONSHOT.value();

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public void moonshotChatStreamTransientError() {
public void moonshotChatStreamNonTransientError() {
when(moonshotApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block());
}

}
Loading

0 comments on commit 1606383

Please sign in to comment.