diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 70d718c9a6..31c56dc80c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -198,6 +198,12 @@ trait HasCustomHeaders extends HasServiceParams { setScalarParam(customHeaders, v) } + // For Pyspark compatability accept Java HashMap as input to parameter + // py4J only natively supports conversions from Python Dict to Java HashMap + def setCustomHeaders(v: java.util.HashMap[String,String]): this.type = { + setCustomHeaders(v.asScala.toMap) + } + def getCustomHeaders: Map[String, String] = getScalarParam(customHeaders) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala index 9fd5d4b3a8..89458decf8 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletionSuite.scala @@ -154,6 +154,7 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") + lazy val customHeadersValues: Map[String, String] = Map("X-ModelType" -> "gpt-4-turbo-chat-completions") val customEndpointCompletion = new OpenAIChatCompletion() .setCustomUrlRoot(customRootUrlValue) @@ -167,8 +168,7 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] .setCustomServiceName(openAIServiceName) } else { customEndpointCompletion.setAADToken(accessToken) - .setCustomHeaders(Map("X-ModelType" -> "gpt-4-turbo-chat-completions", - "X-ScenarioGUID" -> "7687c733-45b0-425b-82b3-05eb4eb70247")) + .setCustomHeaders(customHeadersValues) } testCompletion(customEndpointCompletion, goodDf)