Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader #4031

Merged
merged 5 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 2 additions & 4 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from abc import ABC, abstractmethod
from typing import Optional

import yaml

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
Expand All @@ -18,6 +16,7 @@
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.position_helper import get_position_map, sort_by_position_map


Expand Down Expand Up @@ -154,8 +153,7 @@ def predefined_models(self) -> list[AIModelEntity]:
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)

new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from abc import ABC, abstractmethod

import yaml

from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


Expand Down Expand Up @@ -44,10 +43,7 @@ def get_provider_schema(self) -> ProviderEntity:

# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(yaml_path, ignore_error=True)

try:
# yaml_data to entity
Expand Down
34 changes: 16 additions & 18 deletions api/core/tools/provider/builtin_tool_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from os import listdir, path
from typing import Any

from yaml import FullLoader, load

from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import (
Expand All @@ -15,6 +13,7 @@
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool
from core.tools.utils.yaml_utils import load_yaml_file
from core.utils.module_import_helper import load_single_subclass_from_source


Expand All @@ -28,10 +27,9 @@ def __init__(self, **data: Any) -> None:
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
with open(yaml_path, 'rb') as f:
provider_yaml = load(f.read(), FullLoader)
except:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
provider_yaml = load_yaml_file(yaml_path)
except Exception as e:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}')

if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None:
# set credentials name
Expand All @@ -58,18 +56,18 @@ def _get_builtin_tools(self) -> list[Tool]:
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
with open(path.join(tool_path, tool_file), encoding='utf-8') as f:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file))

# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))

self.tools = tools
return tools
Expand Down
34 changes: 17 additions & 17 deletions api/core/tools/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
deep copy credentials
"""
return deepcopy(credentials)

def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
Expand All @@ -39,9 +39,9 @@ def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted

return credentials

def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
Expand All @@ -58,7 +58,7 @@ def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) +\
'*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
Expand All @@ -72,7 +72,7 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
Expand All @@ -92,10 +92,10 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str

cache.set(credentials)
return credentials

def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
Expand All @@ -116,7 +116,7 @@ def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
deep copy parameters
"""
return deepcopy(parameters)

def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
Expand All @@ -139,7 +139,7 @@ def _merge_parameters(self) -> list[ToolParameter]:
current_parameters.append(runtime_parameter)

return current_parameters

def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
Expand All @@ -157,13 +157,13 @@ def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\
'*' * (len(parameters[parameter.name]) - 4) + \
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])

return parameters

def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
Expand All @@ -180,17 +180,17 @@ def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted

return parameters

def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id

return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
Expand All @@ -212,15 +212,15 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
pass

if has_secret_input:
cache.set(parameters)

return parameters

def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
Expand Down
34 changes: 34 additions & 0 deletions api/core/tools/utils/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
import os

import yaml
from yaml import YAMLError


def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict:
    """
    Safely load a YAML file using ``yaml.safe_load``.

    A missing file (or an empty ``file_path``) is never treated as an error:
    it is logged at debug level and an empty dict is returned, regardless of
    ``ignore_error``.

    :param file_path: the path of the YAML file
    :param ignore_error:
        if True, any other failure is logged at warning level and an empty dict is returned
        if False, the failure is re-raised to the caller
    :return: the parsed YAML content; an empty dict when the file is missing or
        contains no document. The top-level node is returned as parsed, so a
        file whose root is a sequence yields a list rather than a dict.
    :raises YAMLError: if reading/parsing the file fails and ``ignore_error`` is False
        (errors raised while opening the file propagate unchanged)
    """
    try:
        if not file_path or not os.path.exists(file_path):
            raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found')

        with open(file_path, encoding='utf-8') as file:
            try:
                content = yaml.safe_load(file)
            except Exception as e:
                # keep the original exception attached as the cause for easier debugging
                raise YAMLError(f'Failed to load YAML file {file_path}: {e}') from e

        # safe_load returns None for an empty document; callers expect a
        # container they can call .get() on or iterate over
        return {} if content is None else content
    except FileNotFoundError as e:
        logging.debug('Failed to load YAML file %s: %s', file_path, e)
        return {}
    except Exception as e:
        if ignore_error:
            logging.warning('Failed to load YAML file %s: %s', file_path, e)
            return {}
        # bare raise preserves the original traceback
        raise
27 changes: 10 additions & 17 deletions api/core/utils/position_helper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr

import yaml
from core.tools.utils.yaml_utils import load_yaml_file


def get_position_map(
Expand All @@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}

with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
position_file_name = os.path.join(folder_path, file_name)
positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {}
index = 0
for _, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
index += 1
return position_map


def sort_by_position_map(
Expand Down
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ select = [
"I002", # missing-required-import
"UP", # pyupgrade rules
"RUF019", # unnecessary-key-check
"S506", # unsafe-yaml-load
]
ignore = [
"F403", # undefined-local-with-import-star
Expand Down
Empty file.
34 changes: 34 additions & 0 deletions api/tests/unit_tests/utils/position_helper/test_position_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from textwrap import dedent

import pytest

from core.utils.position_helper import get_position_map


@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
    """
    Write a sample ``example_positions.yaml`` into ``tmp_path`` and return that folder.

    The sample mixes plain string entries with a commented-out entry, a blank
    line and an integer entry. The working directory is switched to
    ``tmp_path`` for the duration of the test.

    :return: the folder containing the generated YAML file, as a string
    """
    yaml_text = dedent(
        """\
- first
- second
# - commented
- third

- 9999999999999
- forth
""")
    monkeypatch.chdir(tmp_path)
    positions_file = tmp_path.joinpath("example_positions.yaml")
    positions_file.write_text(yaml_text)
    return str(tmp_path)


def test_position_helper(prepare_example_positions_yaml):
    """
    Check the position map built from the sample YAML file.

    Only the four string entries are expected in the result, numbered
    consecutively from 0 in file order.
    """
    expected_positions = {
        'first': 0,
        'second': 1,
        'third': 2,
        'forth': 3,
    }
    position_map = get_position_map(
        folder_path=prepare_example_positions_yaml,
        file_name='example_positions.yaml',
    )
    assert len(position_map) == 4
    assert position_map == expected_positions
Empty file.