Skip to content

Commit

Permalink
refactor(runner): streamline template compilation and execution #157
Browse files Browse the repository at this point in the history
- Refactored `processTemplateCompile` and `prepareExecute` methods to simplify logic and improve readability.
- Removed redundant code and consolidated variable handling.
- Updated `ShireActionPromptBuilder` to use `compileFileContext` for consistent prompt generation.
- Improved error handling
  • Loading branch information
phodal committed Jan 2, 2025
1 parent d35a870 commit 36d261f
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class OnStreamingService {
fun onStreaming(project: Project, chunk: String) {
map.forEach { (sign, service) ->
try {
service.onStreaming(project, chunk, sign.args)
service.
onStreaming(project, chunk, sign.args)
} catch (e: Exception) {
ShirelangNotifications.error(project, "Error on streaming service: ${e.message}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import com.intellij.openapi.project.Project
import com.phodal.shirecore.config.ShireActionLocation
import com.phodal.shirecore.provider.ide.ShirePromptBuilder
import com.phodal.shirelang.actions.base.DynamicShireActionService
import com.phodal.shirelang.run.runner.ShireRunner
import kotlinx.coroutines.runBlocking

class ShireActionPromptBuilder : ShirePromptBuilder {
override fun build(project: Project, actionLocation: String, originPrompt: String): String {
Expand All @@ -12,7 +14,11 @@ class ShireActionPromptBuilder : ShirePromptBuilder {
val action = DynamicShireActionService.getInstance(project).getActions(location)
.firstOrNull() ?: return originPrompt

val initVariables = mapOf("chatPrompt" to originPrompt)
val finalPrompt = runBlocking {
ShireRunner.compileFileContext(project, action.shireFile, initVariables)
}.finalPrompt

return action.shireFile.text?.replace("\$chatPrompt", originPrompt) ?: ""
return finalPrompt
}
}
265 changes: 151 additions & 114 deletions shirelang/src/main/kotlin/com/phodal/shirelang/run/runner/ShireRunner.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,13 @@ class ShireRunner(
private val cancelListeners = mutableSetOf<(String) -> Unit>()

suspend fun execute(parsedResult: ShireParsedResult): String? {
prepareExecute(parsedResult)
prepareExecute(parsedResult, compiledVariables, project, console)

val runResult = CompletableFuture<String>()

val varsMap = variableMap.toMutableMap()
val varsMap = variableFromPostProcessorContext(variableMap)

val data = PostProcessorContext.getData()
val variables = data?.compiledVariables
if (variables?.get("output") != null && variableMap["output"] == null) {
varsMap["output"] = variables["output"].toString()
}

val runnerContext = processTemplateCompile(parsedResult, varsMap)
val runnerContext = processTemplateCompile(parsedResult, varsMap, project, configuration, console)
if (runnerContext.hasError) {
processHandler.exitWithError()
return null
Expand Down Expand Up @@ -127,51 +121,6 @@ class ShireRunner(
}
}

private suspend fun processTemplateCompile(
compileResult: ShireParsedResult, variableMap: Map<String, String>,
): ShireRunnerContext {
val hobbitHole = compileResult.config
val editor = ActionLocationEditor.provide(project, hobbitHole?.actionLocation)

val templateCompiler =
ShireTemplateCompiler(project, hobbitHole, compileResult.variableTable, compileResult.shireOutput, editor)

variableMap.forEach { (key, value) ->
templateCompiler.putCustomVariable(key, value)
}

val promptTextTrim = templateCompiler.compile().trim()
val compiledVariables = templateCompiler.compiledVariables

PostProcessorContext.getData()?.lastTaskOutput?.let {
templateCompiler.putCustomVariable("output", it)
}

if (console != null) {
printCompiledOutput(console, promptTextTrim, configuration)
}

var hasError = false

if (promptTextTrim.isEmpty()) {
console?.print("No content to run", ConsoleViewContentType.ERROR_OUTPUT)
hasError = true
}

if (promptTextTrim.contains(SHIRE_ERROR)) {
hasError = true
}

return ShireRunnerContext(
hobbitHole,
editor = editor,
compileResult,
promptTextTrim,
hasError,
compiledVariables
)
}

fun executeNormalUiTask(runData: ShireRunnerContext, postFunction: PostFunction) {
val agent = runData.compileResult.executeAgent
val hobbitHole = runData.hole
Expand Down Expand Up @@ -235,64 +184,6 @@ class ShireRunner(
}
}

private fun printCompiledOutput(
console: ConsoleViewWrapperBase,
promptText: String,
shireConfiguration: ShireConfiguration,
) {
console.print("Shire Script: ${shireConfiguration.getScriptPath()}\n", ConsoleViewContentType.SYSTEM_OUTPUT)
console.print("Shire Script Compile output:\n", ConsoleViewContentType.SYSTEM_OUTPUT)
PostProcessorContext.getData()?.llmModelName?.let {
console.print("Used model: $it\n", ConsoleViewContentType.SYSTEM_OUTPUT)
}

promptText.split("\n").forEach {
when {
it.contains(SHIRE_ERROR) -> {
console.print(it, ConsoleViewContentType.LOG_ERROR_OUTPUT)
}

else -> {
console.print(it, ConsoleViewContentType.USER_INPUT)
}
}
console.print("\n", ConsoleViewContentType.NORMAL_OUTPUT)
}

console.print("\n--------------------\n", ConsoleViewContentType.NORMAL_OUTPUT)
}

fun prepareExecute(parsedResult: ShireParsedResult) {
val hobbitHole = parsedResult.config
val editor = FileEditorManager.getInstance(project).selectedTextEditor
hobbitHole?.pickupElement(project, editor)

val file = runReadAction {
editor?.let { PsiManager.getInstance(project).findFile(it.virtualFile) }
}

val context = PostProcessorContext.getData() ?: PostProcessorContext(
currentFile = file,
currentLanguage = file?.language,
editor = editor,
compiledVariables = compiledVariables,
llmModelName = hobbitHole?.model
)

PostProcessorContext.updateContextAndVariables(context)

val vars: MutableMap<String, Any?> = compiledVariables.toMutableMap()
hobbitHole?.executeBeforeStreamingProcessor(project, context, console, vars)

val streamingService = project.getService(OnStreamingService::class.java)
streamingService.clearStreamingService()
hobbitHole?.onStreaming?.forEach {
streamingService.registerStreamingService(it, console)
}

hobbitHole?.setupStreamingEndProcessor(project, context)
}

@Synchronized
fun addCancelListener(listener: (String) -> Unit) {
if (isCanceled) cancel(listener)
Expand All @@ -316,8 +207,154 @@ class ShireRunner(
companion object {
fun preAnalysisSyntax(shireFile: ShireFile, project: Project): ShireParsedResult {
val syntaxAnalyzer = ShireSyntaxAnalyzer(project, shireFile, ActionLocationEditor.defaultEditor(project))
val parsedResult = syntaxAnalyzer.parse()
return parsedResult
return syntaxAnalyzer.parse()
}

fun prepareExecute(
parsedResult: ShireParsedResult,
variables: Map<String, Any>,
project: Project,
consoleView: ShireConsoleView?,
): PostProcessorContext {
val hobbitHole = parsedResult.config
val editor = FileEditorManager.getInstance(project).selectedTextEditor
hobbitHole?.pickupElement(project, editor)

val file = runReadAction {
editor?.let { PsiManager.getInstance(project).findFile(it.virtualFile) }
}

val context = PostProcessorContext.getData() ?: PostProcessorContext(
currentFile = file,
currentLanguage = file?.language,
editor = editor,
compiledVariables = variables,
llmModelName = hobbitHole?.model
)

PostProcessorContext.updateContextAndVariables(context)

val vars: MutableMap<String, Any?> = variables.toMutableMap()
hobbitHole?.executeBeforeStreamingProcessor(project, context, consoleView, vars)

val streamingService = project.getService(OnStreamingService::class.java)
streamingService.clearStreamingService()
hobbitHole?.onStreaming?.forEach {
streamingService.registerStreamingService(it, consoleView)
}

hobbitHole?.setupStreamingEndProcessor(project, context)

return context
}

private suspend fun processTemplateCompile(
compileResult: ShireParsedResult,
variableMap: Map<String, String>,
project: Project,
shireConfiguration: ShireConfiguration?,
shireConsoleView: ShireConsoleView?,
): ShireRunnerContext {
val hobbitHole = compileResult.config
val editor = ActionLocationEditor.provide(project, hobbitHole?.actionLocation)

val templateCompiler =
ShireTemplateCompiler(
project,
hobbitHole,
compileResult.variableTable,
compileResult.shireOutput,
editor
)

variableMap.forEach { (key, value) ->
templateCompiler.putCustomVariable(key, value)
}

val promptTextTrim = templateCompiler.compile().trim()
val compiledVariables = templateCompiler.compiledVariables

PostProcessorContext.getData()?.lastTaskOutput?.let {
templateCompiler.putCustomVariable("output", it)
}

if (shireConsoleView != null && shireConfiguration != null) {
printCompiledOutput(shireConsoleView, promptTextTrim, shireConfiguration)
}

var hasError = false

if (promptTextTrim.isEmpty()) {
shireConsoleView?.print("No content to run", ConsoleViewContentType.ERROR_OUTPUT)
hasError = true
}

if (promptTextTrim.contains(SHIRE_ERROR)) {
hasError = true
}

return ShireRunnerContext(
hobbitHole,
editor = editor,
compileResult,
promptTextTrim,
hasError,
compiledVariables
)
}

private fun printCompiledOutput(
console: ConsoleViewWrapperBase,
promptText: String,
shireConfiguration: ShireConfiguration,
) {
console.print("Shire Script: ${shireConfiguration.getScriptPath()}\n", ConsoleViewContentType.SYSTEM_OUTPUT)
console.print("Shire Script Compile output:\n", ConsoleViewContentType.SYSTEM_OUTPUT)
PostProcessorContext.getData()?.llmModelName?.let {
console.print("Used model: $it\n", ConsoleViewContentType.SYSTEM_OUTPUT)
}

promptText.split("\n").forEach {
when {
it.contains(SHIRE_ERROR) -> {
console.print(it, ConsoleViewContentType.LOG_ERROR_OUTPUT)
}

else -> {
console.print(it, ConsoleViewContentType.USER_INPUT)
}
}
console.print("\n", ConsoleViewContentType.NORMAL_OUTPUT)
}

console.print("\n--------------------\n", ConsoleViewContentType.NORMAL_OUTPUT)
}

suspend fun compileFileContext(
project: Project,
shireFile: ShireFile,
initVariables: Map<String, String>,
): ShireRunnerContext {
val parsedResult = preAnalysisSyntax(shireFile, project)
val variables = variableFromPostProcessorContext(initVariables)

/// add context
prepareExecute(parsedResult, variables, project, null)

val runnerContext = processTemplateCompile(parsedResult, variables, project, null, null)
return runnerContext

}

private fun variableFromPostProcessorContext(initValue: Map<String, String>): MutableMap<String, String> {
val varsMap = initValue.toMutableMap()
val data = PostProcessorContext.getData()
val variables = data?.compiledVariables
if (variables?.get("output") != null && initValue["output"] == null) {
varsMap["output"] = variables["output"].toString()
}

return varsMap
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ShireInlineChatService : Disposable {
}

fun prompt(project: Project, prompt: String): String {
return ShirePromptBuilder.provide()?.build(project, ShireActionLocation.INPUT_BOX.name, prompt) ?: prompt
return ShirePromptBuilder.provide()?.build(project, ShireActionLocation.INLINE_CHAT.name, prompt) ?: prompt
}

companion object {
Expand Down

0 comments on commit 36d261f

Please sign in to comment.