Skip to content

Commit

Permalink
add prompt template for reasoning (#300)
Browse files Browse the repository at this point in the history
* add prompt template for reasoning

* fix types for function
  • Loading branch information
khai-meetkai authored Dec 23, 2024
1 parent 8f8bbaa commit 327cbf0
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 1 deletion.
3 changes: 2 additions & 1 deletion functionary/prompt_template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functionary.prompt_template.llava_prompt_template import LlavaLlama
from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1
from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2
from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate


def get_available_prompt_template_versions() -> List[PromptTemplate]:
Expand All @@ -28,7 +29,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]:
# directly add LLavaLlama as it is not a direct subclass of PromptTemplate but the subclass of: Llama3TemplateV3
# we don't use get_prompt_template or this will return the parent class
all_templates_obj.append(LlavaLlama.get_prompt_template())

all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template())
return all_templates_obj


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{# version=v3-llama3.1 #}{%- if not tools is defined -%}
{%- set tools = none -%}
{%- endif -%}

{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%}
{%- if has_code_interpreter -%}
{%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%}
{%- endif -%}

{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%}
{%- if has_reasoning -%}
{%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%}
{%- endif -%}

{#- System message + builtin tools #}
{{- bos_token + "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if has_reasoning %}
{{- "Reasoning Mode: On\n\n" }}
{%- else -%}
{{ "Reasoning Mode: Off\n\n" }}
{%- endif %}
{%- if has_code_interpreter %}
{{- "Environment: ipython\n\n" }}
{%- else -%}
{{ "\n"}}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n\n" }}
{%- if tools %}
{{- "\nYou have access to the following functions:\n\n" }}
{%- for t in tools %}
{%- if "type" in t -%}
{{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }}
{%- else -%}
{{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }}
{%- endif -%}
{{- "\n\n" }}
{%- endfor %}
{{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => `<function`\nparameters => a JSON dict with the function argument name as key and function argument value as value.\nend_tag => `</function>`\n\nHere is an example,\n<function=example_function_name>{"example_name": "example_value"}</function>\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with <function= and end with </function>\n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}}
{%- endif %}
{{- "<|eot_id|>" -}}

{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'system' -%}
{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- elif message['role'] == 'tool' -%}
{{ '<|start_header_id|>ipython<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- else -%}
{%- if (message['content'] and message['content']|length > 0) or ('tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0) -%}
{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
{%- endif -%}
{%- if message['content'] and message['content']|length > 0 -%}
{{ message['content'] }}
{%- endif -%}
{%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%}
{%- for tool_call in message['tool_calls'] -%}
{%- if tool_call["function"]["name"] == "python" -%}
{{ '<|python_tag|>' + tool_call['function']['arguments'] }}
{%- else -%}
{{ '<function=' + tool_call['function']['name'] + '>' + tool_call['function']['arguments'] + '</function>' }}
{%- endif -%}
{%- endfor -%}
{{ '<|eom_id|>' }}
{%- elif message['content'] and message['content']|length > 0 -%}
{{ '<|eot_id|>' }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif -%}
17 changes: 17 additions & 0 deletions functionary/prompt_template/llama31_reasoning_prompt_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Dict, List, Optional

from functionary.prompt_template.llama31_prompt_template import Llama31Template


class Llama31ReasoningTemplate(Llama31Template):
version = "v3-llama3.1-reasoning"

def get_prompt_from_messages(
self,
messages: List[Dict],
tools_or_functions: Optional[List[Dict]] = None,
bos_token: Optional[str] = "",
add_generation_prompt: bool = False,
) -> str:
reasoning_tool = {"type": "reasoning"}
return super().get_prompt_from_messages(messages, tools_or_functions + [reasoning_tool], bos_token, add_generation_prompt)
2 changes: 2 additions & 0 deletions tests/test_request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LlavaLlama,
PromptTemplate,
PromptTemplateV2,
Llama31ReasoningTemplate,
get_available_prompt_template_versions,
)
from functionary.prompt_template.prompt_utils import (
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(self, *args, **kwargs):
Llama3Template: "meetkai/functionary-small-v2.5",
Llama3TemplateV3: "meetkai/functionary-medium-v3.0",
Llama31Template: "meetkai/functionary-small-v3.1",
Llama31ReasoningTemplate: "meetkai/functionary-small-v3.1",
LlavaLlama: "lmms-lab/llama3-llava-next-8b",
}
self.default_text_str = "Normal text generation"
Expand Down

0 comments on commit 327cbf0

Please sign in to comment.