Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LLM jinja2 template prompt #3968

Merged
merged 21 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/core/helper/code_executor/code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from config import get_env
from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
Yeuoly marked this conversation as resolved.
Show resolved Hide resolved
from core.helper.code_executor.python_transformer import PythonTemplateTransformer

# Code Executor
Expand Down
17 changes: 17 additions & 0 deletions api/core/helper/code_executor/jinja2_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from core.helper.code_executor.code_executor import CodeExecutor


class Jinja2Formatter:
    @classmethod
    def format(cls, template: str, inputs: dict) -> str:
        """
        Render a Jinja2 template through the sandboxed code executor.

        Delegates rendering to CodeExecutor's 'jinja2' workflow-code template
        rather than evaluating the template in-process.

        :param template: jinja2 template source
        :param inputs: mapping of variable name -> value exposed to the template
                       (callers pass the prompt-inputs dict)
        :return: the rendered template string (the executor's 'result' field)
        """
        result = CodeExecutor.execute_workflow_code_template(
            language='jinja2', code=template, inputs=inputs
        )

        return result['result']
65 changes: 40 additions & 25 deletions api/core/prompt/advanced_prompt_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.helper.code_executor.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
Expand Down Expand Up @@ -80,29 +81,35 @@ def _get_completion_model_prompt_messages(self,

prompt_messages = []

prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if prompt_template.edition_type == 'basic' or not prompt_template.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

if memory and memory_config:
role_prefix = memory_config.role_prefix
prompt_inputs = self._set_histories_variable(
memory=memory,
memory_config=memory_config,
raw_prompt=raw_prompt,
role_prefix=role_prefix,
prompt_template=prompt_template,
prompt_inputs=prompt_inputs,
model_config=model_config
)
if memory and memory_config:
role_prefix = memory_config.role_prefix
prompt_inputs = self._set_histories_variable(
memory=memory,
memory_config=memory_config,
raw_prompt=raw_prompt,
role_prefix=role_prefix,
prompt_template=prompt_template,
prompt_inputs=prompt_inputs,
model_config=model_config
)

if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
if query:
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)

prompt = prompt_template.format(
prompt_inputs
)
prompt = prompt_template.format(
prompt_inputs
)
else:
prompt = raw_prompt
prompt_inputs = inputs

prompt = Jinja2Formatter.format(prompt, prompt_inputs)

if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
Expand Down Expand Up @@ -135,14 +142,22 @@ def _get_chat_model_prompt_messages(self,
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item.text

prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if prompt_item.edition_type == 'basic' or not prompt_item.edition_type:
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = prompt_template.format(
prompt_inputs
)
prompt = prompt_template.format(
prompt_inputs
)
elif prompt_item.edition_type == 'jinja2':
prompt = raw_prompt
prompt_inputs = inputs

prompt = Jinja2Formatter.format(prompt, prompt_inputs)
else:
raise ValueError(f'Invalid edition type: {prompt_item.edition_type}')

if prompt_item.role == PromptMessageRole.USER:
prompt_messages.append(UserPromptMessage(content=prompt))
Expand Down
4 changes: 3 additions & 1 deletion api/core/prompt/entities/advanced_prompt_entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional

from pydantic import BaseModel

Expand All @@ -11,13 +11,15 @@ class ChatModelMessage(BaseModel):
"""
text: str
role: PromptMessageRole
edition_type: Optional[Literal['basic', 'jinja2']]


class CompletionModelPromptTemplate(BaseModel):
"""
Completion Model Prompt Template.
"""
text: str
edition_type: Optional[Literal['basic', 'jinja2']]


class MemoryConfig(BaseModel):
Expand Down
21 changes: 20 additions & 1 deletion api/core/workflow/nodes/llm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector


class ModelConfig(BaseModel):
Expand Down Expand Up @@ -37,13 +38,31 @@ class Configs(BaseModel):
enabled: bool
configs: Optional[Configs] = None

class PromptConfig(BaseModel):
    """
    Prompt Config.

    Extra prompt-rendering configuration for the LLM node; currently only
    carries the variable selectors exposed to jinja2-edition templates.
    """
    # variables made available inside jinja2 templates;
    # None for nodes that use no jinja2-edition prompts
    jinja2_variables: Optional[list[VariableSelector]] = None

class LLMNodeChatModelMessage(ChatModelMessage):
    """
    LLM Node Chat Model Message.

    Extends ChatModelMessage with the raw jinja2 template source used when
    this message's edition_type is 'jinja2'.
    """
    # raw jinja2 template; copied into `text` before rendering when
    # edition_type == 'jinja2'
    jinja2_text: Optional[str] = None

class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    """
    LLM Node Completion Model Prompt Template.

    Extends CompletionModelPromptTemplate with the raw jinja2 template source
    used when the template's edition_type is 'jinja2'.
    """
    # raw jinja2 template; copied into `text` before rendering when
    # edition_type == 'jinja2'
    jinja2_text: Optional[str] = None

class LLMNodeData(BaseNodeData):
    """
    LLM Node Data.

    Configuration payload for an LLM workflow node: model settings, prompt
    template(s), optional jinja2 prompt config, memory, context and vision.
    """
    model: ModelConfig
    # chat models take a list of messages; completion models a single template
    prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
    # jinja2 variable selectors; Optional (default None) so nodes created
    # before jinja2 support still validate
    prompt_config: Optional[PromptConfig] = None
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig
128 changes: 119 additions & 9 deletions api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from collections.abc import Generator
from copy import deepcopy
from typing import Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
Expand All @@ -17,11 +19,15 @@
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
Expand All @@ -39,16 +45,24 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
node_data = cast(LLMNodeData, deepcopy(self.node_data))

node_inputs = None
process_data = None

try:
# init messages template
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)

# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)

# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)

# merge inputs
inputs.update(jinja_inputs)

node_inputs = {}

# fetch files
Expand Down Expand Up @@ -183,6 +197,86 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
usage = LLMUsage.empty_usage()

return full_text, usage

def _transform_chat_messages(self,
        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
    """
    Swap the jinja2 template source into `text` for jinja2-edition prompts.

    Mutates the given message objects in place and returns them.

    :param messages: chat message list, or a single completion prompt template
    :return: the (mutated) messages object
    """

    if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
        if messages.edition_type == 'jinja2':
            # fall back to '' so `text` (typed str) never becomes None
            # when jinja2_text was left unset
            messages.text = messages.jinja2_text or ''

        return messages

    for message in messages:
        if message.edition_type == 'jinja2':
            message.text = message.jinja2_text or ''

    return messages

def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
    """
    Fetch values for jinja2 template variables and normalize them to strings.

    :param node_data: node data (uses prompt_config.jinja2_variables)
    :param variable_pool: variable pool the selectors are resolved against
    :return: mapping of jinja2 variable name -> stringified value
    """
    variables: dict[str, str] = {}

    if not node_data.prompt_config:
        return variables

    # hoisted out of the loop: no need to redefine the helper per variable
    def parse_dict(d: dict) -> str:
        """
        Parse dict into string
        """
        # check if it's a context structure (retrieval result): use its content
        if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
            return d['content']

        # else, serialize as JSON, falling back to str() for non-serializable dicts
        try:
            return json.dumps(d, ensure_ascii=False)
        except Exception:
            return str(d)

    for variable_selector in node_data.prompt_config.jinja2_variables or []:
        value = variable_pool.get_variable_value(
            variable_selector=variable_selector.value_selector
        )

        if isinstance(value, str):
            pass  # already a string
        elif isinstance(value, list):
            # one line per item: dicts via parse_dict, everything else via str()
            parts = [
                parse_dict(item) if isinstance(item, dict) else str(item)
                for item in value
            ]
            value = '\n'.join(parts).strip()
        elif isinstance(value, dict):
            value = parse_dict(value)
        else:
            # ints, floats and any other type are stringified
            value = str(value)

        variables[variable_selector.variable] = value

    return variables

def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Expand Down Expand Up @@ -531,14 +625,12 @@ def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage:
db.session.commit()

@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)

prompt_template = node_data.prompt_template

Expand Down Expand Up @@ -571,6 +663,22 @@ def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData)
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]

if node_data.prompt_config:
enable_jinja = False

if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type == 'jinja2':
enable_jinja = True
break
else:
if prompt_template.edition_type == 'jinja2':
enable_jinja = True

if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector

return variable_mapping

@classmethod
Expand All @@ -588,7 +696,8 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"prompts": [
{
"role": "system",
"text": "You are a helpful AI assistant."
"text": "You are a helpful AI assistant.",
"edition_type": "basic"
}
]
},
Expand All @@ -600,7 +709,8 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"prompt": {
"text": "Here is the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic"
},
"stop": ["Human:"]
}
Expand Down
1 change: 1 addition & 0 deletions api/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest~=8.1.1
pytest-benchmark~=4.0.0
pytest-env~=1.1.3
pytest-mock~=3.14.0
jinja2~=3.1.2