Skip to content

Commit

Permalink
Comment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyam Sai committed Jul 10, 2024
1 parent e8fcd6e commit fe6b208
Showing 1 changed file with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ class OpenAIPrompt(override val uid: String) extends Transformer
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"),
messagesCol -> (this.uid + "_messages"),
dropPrompt -> true,
dropMessages -> true,
systemPrompt -> defaultSystemPrompt,
timeout -> 360.0
timeout -> 360.0,
)

override def setCustomServiceName(v: String): this.type = {
Expand All @@ -121,7 +122,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
val messageColName = df.withDerivativeCol("messages")
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

Expand Down Expand Up @@ -160,8 +161,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}, dataset.columns.length)
}

private val legacyModels = Set("gpt-35-turbo", "gpt-35-turbo-16k", "gpt-35-turbo-instruct",
"text-davinci-002", "text-davinci-003")
private val legacyModels = Set("ada","babbage", "curie", "davinci",
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {

Expand Down Expand Up @@ -195,7 +197,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion.setMessagesCol("messages")
chatCompletion.setMessagesCol(getMessagesCol)
chatCompletion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
Expand Down

0 comments on commit fe6b208

Please sign in to comment.