Skip to content

Commit

Permalink
Fixing Fuzzing test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyam Sai committed Jul 13, 2024
1 parent fe6b208 commit 8839102
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer

def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value)

val dropMessages = new BooleanParam(
this, "dropMessages", "whether to drop the column of messages after templating (when using gpt-4 or higher)")

def getDropMessages: Boolean = $(dropMessages)

def setDropMessages(value: Boolean): this.type = set(dropMessages, value)

val systemPrompt = new Param[String](
this, "systemPrompt", "The initial system prompt to be used.")

Expand All @@ -93,9 +86,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
errorCol -> (this.uid + "_error"),
messagesCol -> (this.uid + "_messages"),
dropPrompt -> true,
dropMessages -> true,
systemPrompt -> defaultSystemPrompt,
timeout -> 360.0,
timeout -> 360.0
)

override def setCustomServiceName(v: String): this.type = {
Expand Down Expand Up @@ -133,7 +125,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

if (getDropMessages) {
if (getDropPrompt) {
results.drop(messageColName)
} else {
results
Expand Down Expand Up @@ -197,16 +189,14 @@ class OpenAIPrompt(override val uid: String) extends Transformer
override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion.setMessagesCol(getMessagesCol)
chatCompletion
.transformSchema(schema)
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
.add(getPostProcessing, getParser.outputSchema)
case completion: OpenAICompletion =>
completion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

test("Setting and Keeping Messages Col - Gpt 4") {
promptGpt4.setMessagesCol("messages")
.setDropMessages(false)
.setDropPrompt(false)
.setPromptTemplate(
"""Classify each word as to whether they are an F1 team or not
|ferrari: TRUE
Expand Down

0 comments on commit 8839102

Please sign in to comment.