Better ollama load balancing #1276

Draft
wants to merge 6 commits into base: dev
4 changes: 3 additions & 1 deletion .env.example
@@ -1,6 +1,8 @@
# Ollama URL for the backend to connect
# The path '/ollama' will be redirected to the specified backend URL
OLLAMA_BASE_URL='http://localhost:11434'
OLLAMA_BASE_URLS='http://localhost:11434;http://localhost:11434;http://localhost:11434'
OLLAMA_LB_WEIGHTS='1;2;3'
OLLAMA_LB_POLICY='weighted-round-robin'

OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
73 changes: 73 additions & 0 deletions backend/apps/ollama/load_balancer.py
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod


class LoadBalancer:
    def __init__(self, policy=None):
        self.models_map = None  # Maps model name -> details, incl. the "urls" list of server indices
        if policy is None:
            policy = RoundRobinPolicy()
        self.policy = policy  # Load balancing policy

    def set_model_map(self, models_map):
        self.models_map = models_map

    def get_server_idx_for_model(self, model: str):
        """
        Get a server idx to handle requests for the given model using the specified policy.
        """
        selected_server = self.policy.select_server(self.models_map, model)
        print(f"selected_server: {selected_server}")
        return selected_server


class LoadBalancingPolicy(ABC):
    @abstractmethod
    def select_server(self, models_map, model):
        raise NotImplementedError


class RoundRobinPolicy(LoadBalancingPolicy):
    def __init__(self):
        self.current_index = 0

    def select_server(self, models_map, model: str):
        # sorted() gives a stable order; iterating a bare set would make the rotation unpredictable
        servers_supporting_model = sorted(set(models_map.get(model, {}).get("urls", [])))

        if not servers_supporting_model:
            return None  # No server supports the requested model

        selected_server = servers_supporting_model[
            self.current_index % len(servers_supporting_model)
        ]
        self.current_index += 1
        return selected_server


class WeightedRoundRobinPolicy(LoadBalancingPolicy):
    def __init__(self):
        self.weights = {}
        self.current_index = 0

    def set_weights(self, weights):
        # A list of weights indexed by server idx (as parsed from OLLAMA_LB_WEIGHTS)
        self.weights = weights

    def select_server(self, models_map, model: str):
        servers_supporting_model = sorted(set(models_map.get(model, {}).get("urls", [])))

        if not servers_supporting_model:
            return None  # No server supports the requested model

        # Only count the weights of servers that can actually serve this model;
        # otherwise part of the modulo window could never map to a server
        server_weights = {
            server: self.weights[server] for server in servers_supporting_model
        }
        total_weight = sum(server_weights.values())
        selected_index = self.current_index % total_weight

        self.current_index += 1

        for server_idx, weight in server_weights.items():
            if selected_index < weight:
                return server_idx
            selected_index -= weight

        return servers_supporting_model[0]  # Fallback to the first server
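
A minimal sketch of how the two policies rotate, for review purposes. It assumes the models_map shape that main.py builds into app.state.MODELS (model name -> details with a "urls" list of indices into OLLAMA_BASE_URLS); the model name and map below are made up for illustration, and the import resolves only when run from the backend directory.

from apps.ollama.load_balancer import (
    LoadBalancer,
    RoundRobinPolicy,
    WeightedRoundRobinPolicy,
)

models_map = {"llama2": {"urls": [0, 1, 2]}}  # hypothetical model -> server indices

rr = LoadBalancer(RoundRobinPolicy())
rr.set_model_map(models_map)
print([rr.get_server_idx_for_model("llama2") for _ in range(6)])
# -> [0, 1, 2, 0, 1, 2]

weighted = WeightedRoundRobinPolicy()
weighted.set_weights([1, 2, 3])  # weights[i] applies to server idx i
wrr = LoadBalancer(weighted)
wrr.set_model_map(models_map)
print([wrr.get_server_idx_for_model("llama2") for _ in range(6)])
# -> [0, 1, 1, 2, 2, 2]: of every 6 requests, 1 hits idx 0, 2 hit idx 1, 3 hit idx 2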
74 changes: 63 additions & 11 deletions backend/apps/ollama/main.py
Expand Up @@ -29,9 +29,23 @@


from apps.web.models.users import Users
from apps.ollama.load_balancer import (
    LoadBalancer,
    RoundRobinPolicy,
    WeightedRoundRobinPolicy,
)
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user

from utils.misc import calculate_sha256
from config import (
    OLLAMA_BASE_URLS,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
    OLLAMA_LB_WEIGHTS,
    OLLAMA_LB_POLICY,
)


from config import (
    SRC_LOG_LEVELS,
@@ -59,9 +73,11 @@
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.OLLAMA_LB = LoadBalancer()
app.state.OLLAMA_LB_POLICY = OLLAMA_LB_POLICY  # honor the configured policy instead of hardcoding it
app.state.OLLAMA_LB_WEIGHTS = OLLAMA_LB_WEIGHTS
app.state.MODELS = {}


REQUEST_POOL = []


@@ -70,6 +86,22 @@
# least connections, or least response time for better resource utilization and performance optimization.


def get_ollama_load_balanced_url(model_name: str):
    if app.state.OLLAMA_LB.models_map is None:
        if app.state.OLLAMA_LB_POLICY == "weighted-round-robin":
            lb_policy = WeightedRoundRobinPolicy()
            lb_policy.set_weights(app.state.OLLAMA_LB_WEIGHTS)
        else:  # Fallback to "round-robin"
            lb_policy = RoundRobinPolicy()
        lb = LoadBalancer(lb_policy)
        lb.set_model_map(app.state.MODELS)
        app.state.OLLAMA_LB = lb

    url_idx = app.state.OLLAMA_LB.get_server_idx_for_model(model_name)
    # print(url_idx)
    return url_idx


@app.middleware("http")
async def check_url(request: Request, call_next):
    if len(app.state.MODELS) == 0:
@@ -104,6 +136,30 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}


@app.get("/lb")
async def get_ollama_load_balancer(user=Depends(get_admin_user)):
    return {
        "OLLAMA_LB_POLICY": app.state.OLLAMA_LB_POLICY,
        "OLLAMA_LB_WEIGHTS": app.state.OLLAMA_LB_WEIGHTS,
    }


class LoadBalancerConfig(BaseModel):
    policy: str
    weights: Optional[List[int]] = None


@app.post("/lb/update")
async def update_ollama_load_balancer(
    form_data: LoadBalancerConfig, user=Depends(get_admin_user)
):
    app.state.OLLAMA_LB_POLICY = form_data.policy
    app.state.OLLAMA_LB_WEIGHTS = form_data.weights
    # Reset the balancer so the next request rebuilds it with the new settings
    app.state.OLLAMA_LB = LoadBalancer()

    print(app.state.OLLAMA_LB_POLICY)
    return {"OLLAMA_LB_POLICY": app.state.OLLAMA_LB_POLICY}


@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
@@ -205,9 +261,7 @@ async def get_ollama_tags(
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):

    if url_idx == None:

        # returns lowest version
        tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
        responses = await asyncio.gather(*tasks)
@@ -566,7 +620,8 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
    print(f"form_data: {form_data.name}")
    url_idx = get_ollama_load_balanced_url(form_data.name)
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

@@ -612,7 +667,7 @@ async def generate_embeddings(
):
    if url_idx == None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
            url_idx = get_ollama_load_balanced_url(form_data.model)
        else:
            raise HTTPException(
                status_code=400,
@@ -669,10 +724,9 @@ async def generate_completion(
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):

    if url_idx == None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
            url_idx = get_ollama_load_balanced_url(form_data.model)
        else:
            raise HTTPException(
                status_code=400,
@@ -767,10 +821,9 @@ async def generate_chat_completion(
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):

    if url_idx == None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
            url_idx = get_ollama_load_balanced_url(form_data.model)
        else:
            raise HTTPException(
                status_code=400,
@@ -871,10 +924,9 @@ async def generate_openai_chat_completion(
    url_idx: Optional[int] = None,
    user=Depends(get_current_user),
):

    if url_idx == None:
        if form_data.model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
            url_idx = get_ollama_load_balanced_url(form_data.model)
        else:
            raise HTTPException(
                status_code=400,
8 changes: 8 additions & 0 deletions backend/config.py
Expand Up @@ -278,6 +278,14 @@ def create_config_file(file_path):

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]

OLLAMA_LB_WEIGHTS = os.environ.get("OLLAMA_LB_WEIGHTS", "")
OLLAMA_LB_WEIGHTS = (
    [int(w.strip()) for w in OLLAMA_LB_WEIGHTS.split(";")]
    if OLLAMA_LB_WEIGHTS != ""
    else [1] * len(OLLAMA_BASE_URLS)  # assumed default: equal weight for every URL
)

OLLAMA_LB_POLICY = os.environ.get("OLLAMA_LB_POLICY", "round-robin")
OLLAMA_LB_POLICY = OLLAMA_LB_POLICY if OLLAMA_LB_POLICY != "" else "round-robin"
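
Note that nothing here enforces alignment between the two lists: OLLAMA_LB_WEIGHTS is positional, so weights[i] must correspond to the i-th entry of OLLAMA_BASE_URLS (WeightedRoundRobinPolicy indexes its weights by server idx). A quick sanity check on the .env.example values, reusing the same split-on-';' parsing:

urls = "http://localhost:11434;http://localhost:11434;http://localhost:11434".split(";")
weights = [int(w.strip()) for w in "1;2;3".split(";")]
assert len(weights) == len(urls)  # weights[i] is the weight of urls[i]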


####################################
# OPENAI_API
Empty file modified backend/dev.sh
100644 → 100755
Empty file.
68 changes: 68 additions & 0 deletions src/lib/apis/ollama/index.ts
Expand Up @@ -67,6 +67,74 @@ export const updateOllamaUrls = async (token: string = '', urls: string[]) => {
	return res.OLLAMA_BASE_URLS;
};

export const getOllamaLoadBalancer = async (token: string = '') => {
	let error = null;

	const res = await fetch(`${OLLAMA_API_BASE_URL}/lb`, {
		method: 'GET',
		headers: {
			Accept: 'application/json',
			'Content-Type': 'application/json',
			...(token && { authorization: `Bearer ${token}` })
		}
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			console.log(err);
			if ('detail' in err) {
				error = err.detail;
			} else {
				error = 'Server connection failed';
			}
			return null;
		});

	if (error) {
		throw error;
	}

	return res;
};

type OllamaLoadBalancerConfig = {
	policy: string;
	weights?: number[];
};

export const updateOllamaLoadBalancer = async (
	token: string = '',
	lbConfig: OllamaLoadBalancerConfig
) => {
	let error = null;

	const res = await fetch(`${OLLAMA_API_BASE_URL}/lb/update`, {
		method: 'POST',
		headers: {
			Accept: 'application/json',
			'Content-Type': 'application/json',
			...(token && { authorization: `Bearer ${token}` })
		},
		body: JSON.stringify({
			policy: lbConfig.policy,
			weights: lbConfig.weights
		})
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			console.log(err);
			if ('detail' in err) {
				error = err.detail;
			} else {
				error = 'Server connection failed';
			}
			return null;
		});

	if (error) {
		throw error;
	}

	return res.OLLAMA_LB_POLICY;
};

export const getOllamaVersion = async (token: string = '') => {
	let error = null;
