Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: custom function plugin support #798

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
172 changes: 172 additions & 0 deletions backend/apps/functions/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from pathlib import Path
import ast
import builtins


from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)

from fastapi.middleware.cors import CORSMiddleware
from apps.functions.security import ALLOWED_MODULES, ALLOWED_BUILTINS, custom_import
from utils.utils import get_current_user, get_admin_user


from config import FUNCTIONS_DIR
from constants import ERROR_MESSAGES


from pydantic import BaseModel

from typing import Optional

app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


@app.get("/")
async def get_status():
return {"status": True}


class FunctionForm(BaseModel):
name: str
content: str


@app.post("/add")
def add_function(
form_data: FunctionForm,
user=Depends(get_admin_user),
):
try:
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
if not Path(filename).exists():
with open(filename, "w") as file:
file.write(form_data.content)
return f"{form_data.name}.py" in list(
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
)
else:
raise Exception("Function already exists")
except Exception as e:
print(e)
return False


@app.post("/update")
def update_function(
form_data: FunctionForm,
user=Depends(get_admin_user),
):
try:
filename = f"{FUNCTIONS_DIR}/{form_data.name}.py"
if Path(filename).exists():
with open(filename, "w") as file:
file.write(form_data.content)
return f"{form_data.name}.py" in list(
map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*"))
)
else:
raise Exception("Function does not exist")
except Exception as e:
print(e)
return False


@app.get("/check/{function}")
def check_function(
function: str,
user=Depends(get_admin_user),
):
filename = f"{FUNCTIONS_DIR}/{function}.py"

# Check if the function file exists
if not Path(filename).is_file():
raise HTTPException(status_code=404, detail="Function not found")

# Read the code from the file
with open(filename, "r") as file:
code = file.read()

return {"name": function, "content": code}


@app.get("/list")
def list_functions(
user=Depends(get_admin_user),
):
files = list(map(lambda x: x.name, Path(FUNCTIONS_DIR).rglob("./*")))
return files


def validate_imports(code):
try:
tree = ast.parse(code)
except SyntaxError as e:
raise HTTPException(status_code=400, detail=f"Syntax error in function: {e}")

for node in ast.walk(tree):
if isinstance(node, ast.Import):
module_names = [alias.name for alias in node.names]
elif isinstance(node, ast.ImportFrom):
module_names = [node.module]
else:
continue

for name in module_names:
if name not in ALLOWED_MODULES:
raise HTTPException(
status_code=400, detail=f"Import of module {name} is not allowed"
)


@app.post("/exec/{function}")
def exec_function(
function: str,
kwargs: Optional[dict] = None,
user=Depends(get_current_user),
):
filename = f"{FUNCTIONS_DIR}/{function}.py"

# Check if the function file exists
if not Path(filename).is_file():
raise HTTPException(status_code=404, detail="Function not found")

# Read the code from the file
with open(filename, "r") as file:
code = file.read()

validate_imports(code)

try:
# Execute the code within a restricted namespace
namespace = {name: getattr(builtins, name) for name in ALLOWED_BUILTINS}
namespace["__import__"] = custom_import
exec(code, namespace)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Function: {e}")

# Check if the function exists in the namespace
if "main" not in namespace or not callable(namespace["main"]):
raise HTTPException(status_code=400, detail="Invalid function")

try:
# Execute the function with provided kwargs
result = namespace["main"](kwargs) if kwargs else namespace["main"]()
return result
except Exception as e:
raise HTTPException(status_code=400, detail=f"Function: {e}")
140 changes: 140 additions & 0 deletions backend/apps/functions/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
ALLOWED_MODULES = {
"pydantic",
"math",
"json",
"time",
"datetime",
"requests",
} # Add allowed modules here


def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
if name in ALLOWED_MODULES:
return __import__(name, globals, locals, fromlist, level)
raise ImportError(f"Import of module {name} is not allowed")


# Define a restricted set of builtins
ALLOWED_BUILTINS = {
"ArithmeticError",
"AssertionError",
"AttributeError",
"BaseException",
"BufferError",
"BytesWarning",
"DeprecationWarning",
"EOFError",
"Ellipsis",
"EnvironmentError",
"Exception",
"False",
"FloatingPointError",
"FutureWarning",
"GeneratorExit",
"IOError",
"ImportError",
"ImportWarning",
"IndentationError",
"IndexError",
"KeyError",
"KeyboardInterrupt",
"LookupError",
"MemoryError",
"NameError",
"None",
"NotImplemented",
"NotImplementedError",
"OSError",
"OverflowError",
"PendingDeprecationWarning",
"ReferenceError",
"RuntimeError",
"RuntimeWarning",
"StopIteration",
"SyntaxError",
"SyntaxWarning",
"SystemError",
"SystemExit",
"TabError",
"True",
"TypeError",
"UnboundLocalError",
"UnicodeDecodeError",
"UnicodeEncodeError",
"UnicodeError",
"UnicodeTranslateError",
"UnicodeWarning",
"UserWarning",
"ValueError",
"Warning",
"ZeroDivisionError",
"__build_class__",
"__debug__",
"__import__",
"abs",
"all",
"any",
"ascii",
"bin",
"bool",
"bytearray",
"bytes",
"callable",
"chr",
"classmethod",
"compile",
"complex",
"delattr",
"dict",
"dir",
"divmod",
"enumerate",
"eval",
"exec",
"filter",
"float",
"format",
"frozenset",
"getattr",
"globals",
"hasattr",
"hash",
"hex",
"id",
"input",
"int",
"isinstance",
"issubclass",
"iter",
"len",
"list",
"locals",
"map",
"max",
"memoryview",
"min",
"next",
"object",
"oct",
"open",
"ord",
"pow",
"print",
"property",
"range",
"repr",
"reversed",
"round",
"set",
"setattr",
"slice",
"sorted",
"staticmethod",
"str",
"sum",
"super",
"tuple",
"type",
"vars",
"zip",
}
7 changes: 7 additions & 0 deletions backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ def parse_section(section):
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Functions DIR
####################################

FUNCTIONS_DIR = f"{DATA_DIR}/functions"
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)

####################################
# Docs DIR
####################################
Expand Down
3 changes: 2 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.audio.main import app as audio_app
from apps.functions.main import app as functions_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app

Expand Down Expand Up @@ -61,10 +62,10 @@ async def check_url(request: Request, call_next):

app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/functions/api/v1", functions_app)


@app.get("/api/config")
Expand Down