Skip to content

Commit

Permalink
Add Azure OpenAI Content Filter Support
Browse files Browse the repository at this point in the history
Azure enforces content filtering on all completion requests.
To reduce the overhead of content filtering, they’ve added asychronous mode, which basically outputs specialized bodies at the end of streaming output.

https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter?tabs=warning%2Cpython-new#annotation-message

The basic structure is this:
data: {"choices":[{"content_filter_offsets":{"check_offset":33188,"start_offset":33188,"end_offset":33546},"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":null,"index":0}],"created":0,"id":"","model":"","object":""}
  • Loading branch information
rasharab committed May 28, 2024
1 parent a2ce122 commit 5883420
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ internal class ChatMessageAssembler {
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null
private val toolCallsAssemblers = mutableMapOf<Int, ToolCallAssembler>()
private var chatContentFilterOffsets = mutableListOf<ContentFilterOffsets>()
private var chatContentFilterResults = mutableListOf<ContentFilterResults>()

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chunk.delta.run {
chunk.delta?.run {
role?.let { chatRole = it }
content?.let { chatContent.append(it) }
functionCall?.let { call ->
Expand All @@ -30,6 +32,12 @@ internal class ChatMessageAssembler {
assembler.merge(toolCall)
}
}
chunk.contentFilterOffsets?.also {
chatContentFilterOffsets.add(it)
}
chunk.contentFilterResults?.also {
chatContentFilterResults.add(it)
}
return this
}

Expand All @@ -39,6 +47,8 @@ internal class ChatMessageAssembler {
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
this.contentFilterOffsets = chatContentFilterOffsets
this.contentFilterResults = chatContentFilterResults
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import com.aallam.openai.api.chat.ChatChunk
import com.aallam.openai.api.chat.ChatDelta
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.chat.ContentFilterOffsets
import com.aallam.openai.api.chat.ContentFilterResult
import com.aallam.openai.api.chat.ContentFilterResults
import com.aallam.openai.api.core.FinishReason
import com.aallam.openai.client.extension.mergeToChatMessage
import kotlin.test.Test
Expand All @@ -20,6 +23,8 @@ class TestChatChunk {
role = ChatRole(role = "assistant"),
content = ""
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -28,6 +33,8 @@ class TestChatChunk {
role = null,
content = "The"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -36,6 +43,8 @@ class TestChatChunk {
role = null,
content = " World"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -44,6 +53,8 @@ class TestChatChunk {
role = null,
content = " Series"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -52,6 +63,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -60,6 +73,8 @@ class TestChatChunk {
role = null,
content = " "
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -68,6 +83,8 @@ class TestChatChunk {
role = null,
content = "202"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -76,6 +93,8 @@ class TestChatChunk {
role = null,
content = "0"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -84,6 +103,8 @@ class TestChatChunk {
role = null,
content = " is"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -92,6 +113,8 @@ class TestChatChunk {
role = null,
content = " being held"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -100,6 +123,8 @@ class TestChatChunk {
role = null,
content = " in"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -108,6 +133,8 @@ class TestChatChunk {
role = null,
content = " Texas"
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -116,6 +143,8 @@ class TestChatChunk {
role = null,
content = "."
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = null
),
ChatChunk(
Expand All @@ -124,6 +153,24 @@ class TestChatChunk {
role = null,
content = null
),
contentFilterOffsets = null,
contentFilterResults = null,
finishReason = FinishReason(value = "stop")
),
ChatChunk(
index = 0,
delta = null,
contentFilterOffsets = ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
),
contentFilterResults = ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
),
finishReason = FinishReason(value = "stop")
)
)
Expand All @@ -132,6 +179,21 @@ class TestChatChunk {
role = ChatRole.Assistant,
content = "The World Series in 2020 is being held in Texas.",
name = null,
contentFilterResults = listOf(
ContentFilterResults(
hate = ContentFilterResult(
filtered = false,
severity = "high",
)
)
),
contentFilterOffsets = listOf(
ContentFilterOffsets(
checkOffset = 1,
startOffset = 1,
endOffset = 1,
)
),
)
assertEquals(chatMessage, message)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.ChatCompletionChunk
import com.aallam.openai.api.file.FileSource
import com.aallam.openai.client.internal.JsonLenient
import com.aallam.openai.client.internal.TestFileSystem
import com.aallam.openai.client.internal.testFilePath
import kotlin.test.Test
import okio.buffer

class TestChatCompletionChunk {
@Test
fun testContentFilterDeserialization() {
val json = FileSource(path = testFilePath("json/azureContentFilterChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}

@Test
fun testDeserialization() {
val json = FileSource(path = testFilePath("json/chatChunk.json"), fileSystem = TestFileSystem)
val actualJson = json.source.buffer().readByteArray().decodeToString()
JsonLenient.decodeFromString<ChatCompletionChunk>(actualJson)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"choices": [
{
"content_filter_offsets": {
"check_offset": 33188,
"start_offset": 33188,
"end_offset": 33557
},
"content_filter_results": {
"hate": {
"filtered": false,
"severity": "safe"
},
"self_harm": {
"filtered": false,
"severity": "safe"
},
"sexual": {
"filtered": false,
"severity": "safe"
},
"violence": {
"filtered": false,
"severity": "safe"
}
},
"finish_reason": null,
"index": 0
}
],
"created": 0,
"id": "",
"model": "",
"object": ""
}
16 changes: 16 additions & 0 deletions openai-client/src/commonTest/resources/json/chatChunk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"choices": [
{
"delta": {
"content": " engineering"
},
"finish_reason": null,
"index": 0
}
],
"created": 1716855566,
"id": "chatcmpl-9TeqkT3BJs5zXQq12b204deXcY5nj",
"model": "gpt-4o-2024-05-13",
"object": "chat.completion.chunk",
"system_fingerprint": "fp_5f4bad809a"
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.aallam.openai.api.chat;

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.FinishReason
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -19,7 +18,17 @@ public data class ChatChunk(
/**
* The generated chat message.
*/
@SerialName("delta") public val delta: ChatDelta,
@SerialName("delta") public val delta: ChatDelta? = null,

/**
* Azure content filter offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: ContentFilterOffsets? = null,

/**
* Azure content filter results
*/
@SerialName("content_filter_results") public val contentFilterResults: ContentFilterResults? = null,

/**
* The reason why OpenAI stopped generating.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ public data class ChatMessage(
* Tool call ID.
*/
@SerialName("tool_call_id") public val toolCallId: ToolId? = null,

/**
* Azure Content Filter Results
*/
@SerialName("content_filter_results") public val contentFilterResults: List<ContentFilterResults>? = null,

/**
* Azure Content Filter Offsets
*/
@SerialName("content_filter_offsets") public val contentFilterOffsets: List<ContentFilterOffsets>? = null,
) {

public constructor(
Expand All @@ -54,13 +64,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { TextContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

public constructor(
Expand All @@ -70,13 +84,17 @@ public data class ChatMessage(
functionCall: FunctionCall? = null,
toolCalls: List<ToolCall>? = null,
toolCallId: ToolId? = null,
contentFilterResults: List<ContentFilterResults>? = null,
contentFilterOffsets: List<ContentFilterOffsets>? = null,
) : this(
role = role,
messageContent = content?.let { ListContent(it) },
name = name,
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)

val content: String?
Expand Down Expand Up @@ -282,6 +300,16 @@ public class ChatMessageBuilder {
*/
public var toolCalls: List<ToolCall>? = null

/**
* Azure content filter offsets
*/
public var contentFilterOffsets: List<ContentFilterOffsets>? = null

/**
* Azure content filter results
*/
public var contentFilterResults: List<ContentFilterResults>? = null

/**
* Tool call ID.
*/
Expand Down Expand Up @@ -313,6 +341,8 @@ public class ChatMessageBuilder {
functionCall = functionCall,
toolCalls = toolCalls,
toolCallId = toolCallId,
contentFilterOffsets = contentFilterOffsets,
contentFilterResults = contentFilterResults,
)
}
}
Expand Down
Loading

0 comments on commit 5883420

Please sign in to comment.