Skip to content

Commit

Permalink
feat: Enable GPT-4 in OpenAIPrompt (#2248)
Browse files Browse the repository at this point in the history
* Add OpenAIChatCompletion to OpenAIPrompt

* Parametrize system prompt

* Update cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Update default system prompt


* Update cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

remove unneeded comment


* Update cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

remove unneeded comment


* Allow naming of messsage column and relevant tests

* Comment fixes

* Fixing Fuzzing test errors

---------

Co-authored-by: Shyam Sai <[email protected]>
  • Loading branch information
sss04 and Shyam Sai authored Jul 16, 2024
1 parent b19b991 commit c033077
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -40,6 +41,16 @@ trait HasPromptInputs extends HasServiceParams {

}

trait HasMessagesInput extends Params {
val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,10 @@ import scala.language.existentials
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)

def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))

def urlPath: String = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}

Expand All @@ -20,7 +21,7 @@ import scala.collection.JavaConverters._
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams
with HasOpenAITextParams with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
Expand Down Expand Up @@ -62,18 +63,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
set(postProcessingOptions, v.asScala.toMap)

val dropPrompt = new BooleanParam(
this, "dropPrompt", "whether to drop the column of prompts after templating")
this, "dropPrompt", "whether to drop the column of prompts after templating (when using legacy models)")

def getDropPrompt: Boolean = $(dropPrompt)

def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value)

val systemPrompt = new Param[String](
this, "systemPrompt", "The initial system prompt to be used.")

def getSystemPrompt: String = $(systemPrompt)

def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
"Follow their instructions carefully and be brief if they don't say otherwise."

setDefault(
postProcessing -> "",
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"),
messagesCol -> (this.uid + "_messages"),
dropPrompt -> true,
systemPrompt -> defaultSystemPrompt,
timeout -> 360.0
)

Expand All @@ -82,40 +95,77 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt")
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

logTransform[DataFrame]({
val df = dataset.toDF

val promptColName = df.withDerivativeCol("prompt")

val dfTemplated = df.withColumn(promptColName, Functions.template(getPromptTemplate))

val completion = openAICompletion.setPromptCol(promptColName)

// run completion
val results = completion
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completion.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completion.getOutputCol)

if (getDropPrompt) {
results.drop(promptColName)
} else {
results
val completion = openAICompletion
val promptCol = Functions.template(getPromptTemplate)
val createMessagesUDF = udf((userMessage: String) => {
Seq(
OpenAIMessage("system", getSystemPrompt),
OpenAIMessage("user", userMessage)
)
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

if (getDropPrompt) {
results.drop(messageColName)
} else {
results
}

case completion: OpenAICompletion =>
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)

// run completion
val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completionNamed.getOutputCol)

if (getDropPrompt) {
results.drop(promptColName)
} else {
results
}
}
}, dataset.columns.length)
}

private def openAICompletion: OpenAICompletion = {
// apply template
val completion = new OpenAICompletion()
private val legacyModels = Set("ada","babbage", "curie", "davinci",
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
if (legacyModels.contains(getDeploymentName)) {
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
Expand All @@ -136,10 +186,18 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
}

override def transformSchema(schema: StructType): StructType =
openAICompletion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
.add(getPostProcessing, getParser.outputSchema)
case completion: OpenAICompletion =>
completion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}
}
}

trait OutputParser {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

test("Basic Usage JSON") {
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessing("json")
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
Expand All @@ -62,6 +62,56 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage - Gpt 4") {
val nonNullCount = promptGpt4
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
.transform(df)
.select("outParsed")
.collect()
.count(r => Option(r.getSeq[String](0)).isDefined)

assert(nonNullCount == 3)
}

test("Basic Usage JSON - Gpt 4") {
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessing("json")
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
.select("outParsed")
.where(col("outParsed").isNotNull)
.collect()
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

test("Setting and Keeping Messages Col - Gpt 4") {
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
.setPromptTemplate(
"""Classify each word as to whether they are an F1 team or not
|ferrari: TRUE
|tomato: FALSE
|{text}:
|""".stripMargin)
.transform(df)
.select("messages")
.where(col("messages").isNotNull)
.collect()
.foreach(r => assert(r.get(0) != null))
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq)
}
Expand Down

0 comments on commit c033077

Please sign in to comment.