Skip to content

Commit

Permalink
feat: add transID for ai function (#817)
Browse files Browse the repository at this point in the history
# Description

Add transID for the ai function; each request will have a unique transID,
and we can use `FromTransIDContext` to retrieve it from the `ctx`.
  • Loading branch information
woorui committed May 17, 2024
1 parent b477013 commit 5c9d6ee
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 39 deletions.
1 change: 0 additions & 1 deletion ai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type OverviewResponse struct {

// InvokeRequest is the request from user to BasicAPIServer
type InvokeRequest struct {
ReqID string `json:"req_id"` // ReqID is the request id of the request
Prompt string `json:"prompt"` // Prompt is user input text for chat completion
IncludeCallStack bool `json:"include_call_stack"` // IncludeCallStack is the flag to include call stack in response
}
Expand Down
56 changes: 34 additions & 22 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import (
"net/http"
"time"

gonanoid "github.com/matoous/go-nanoid/v2"
openai "github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
"github.com/yomorun/yomo/pkg/id"
)

const (
Expand Down Expand Up @@ -97,7 +97,10 @@ func WithContextService(handler http.Handler, credential string, zipperAddr stri
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r.WithContext(WithServiceContext(r.Context(), service)))
transID := id.New(32)
ctx := WithTransIDContext(r.Context(), transID)
ctx = WithServiceContext(ctx, service)
handler.ServeHTTP(w, r.WithContext(ctx))
})
}

Expand All @@ -119,18 +122,14 @@ func HandleOverview(w http.ResponseWriter, r *http.Request) {

// HandleInvoke is the handler for POST /invoke
func HandleInvoke(w http.ResponseWriter, r *http.Request) {
service := FromServiceContext(r.Context())
var (
ctx = r.Context()
service = FromServiceContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
reqID, err := gonanoid.New(6)
if err != nil {
ylog.Error("generate reqID", "err", err.Error())
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
return
}

var req ai.InvokeRequest
req.ReqID = reqID

// decode the request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
Expand All @@ -154,8 +153,8 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) {
errCh := make(chan error, 1)
go func(service *Service, req ai.InvokeRequest, baseSystemMessage string) {
// call llm to infer the function and arguments to be invoked
ylog.Debug(">> ai request", "reqID", req.ReqID, "prompt", req.Prompt)
res, err := service.GetInvoke(ctx, req.Prompt, baseSystemMessage, req.ReqID, req.IncludeCallStack)
ylog.Debug(">> ai request", "transID", transID, "prompt", req.Prompt)
res, err := service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack)
if err != nil {
errCh <- err
} else {
Expand Down Expand Up @@ -183,16 +182,13 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) {

// HandleChatCompletions is the handler for POST /chat/completion
func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
service := FromServiceContext(r.Context())
var (
ctx = r.Context()
service = FromServiceContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()

reqID, err := gonanoid.New(6)
if err != nil {
ylog.Error("generate reqID", "err", err.Error())
RespondWithError(w, http.StatusInternalServerError, err)
return
}

var req openai.ChatCompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
ylog.Error("decode request", "err", err.Error())
Expand All @@ -203,7 +199,7 @@ func HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), 90*time.Second)
defer cancel()

if err := service.GetChatCompletions(ctx, req, reqID, w, false); err != nil {
if err := service.GetChatCompletions(ctx, req, transID, w, false); err != nil {
ylog.Error("invoke chat completions", "err", err.Error())
RespondWithError(w, http.StatusBadRequest, err)
return
Expand Down Expand Up @@ -248,3 +244,19 @@ func FromServiceContext(ctx context.Context) *Service {
}
return service
}

type transIDContextKey struct{}

// WithTransIDContext adds the transID to the request context
func WithTransIDContext(ctx context.Context, transID string) context.Context {
return context.WithValue(ctx, transIDContextKey{}, transID)
}

// FromTransIDContext returns the transID from the request context
func FromTransIDContext(ctx context.Context) string {
val, ok := ctx.Value(transIDContextKey{}).(string)
if !ok {
return ""
}
return val
}
35 changes: 19 additions & 16 deletions pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/core/ylog"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
"github.com/yomorun/yomo/pkg/id"
"github.com/yomorun/yomo/serverless"
)

Expand Down Expand Up @@ -164,7 +165,7 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) {
c, ok := s.sfnCallCache[reqID]
s.muCallCache.Unlock()
if !ok {
ylog.Error("[sfn-reducer] req_id not found", "req_id", reqID)
ylog.Error("[sfn-reducer] req_id not found", "trans_id", invoke.TransID, "req_id", reqID)
return
}

Expand Down Expand Up @@ -204,7 +205,7 @@ func (s *Service) GetOverview() (*ai.OverviewResponse, error) {
}

// GetInvoke returns the invoke response
func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, reqID string, includeCallStack bool) (*ai.InvokeResponse, error) {
func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSystemMessage string, transID string, includeCallStack bool) (*ai.InvokeResponse, error) {
// read tools attached to the metadata
tcs, err := register.ListToolCalls(s.Metadata)
if err != nil {
Expand Down Expand Up @@ -243,8 +244,8 @@ func (s *Service) GetInvoke(ctx context.Context, userInstruction string, baseSys
"res_toolcalls", fmt.Sprintf("%+v", res.ToolCalls),
"res_assistant_msgs", fmt.Sprintf("%+v", res.AssistantMessage))

ylog.Debug(">> run function calls", "reqID", reqID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls))
llmCalls, err := s.runFunctionCalls(res.ToolCalls, reqID)
ylog.Debug(">> run function calls", "transID", transID, "res.ToolCalls", fmt.Sprintf("%+v", res.ToolCalls))
llmCalls, err := s.runFunctionCalls(res.ToolCalls, transID, id.New(16))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -322,7 +323,7 @@ func overWriteSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string) o
}

// GetChatCompletions returns the llm api response
func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, reqID string, w http.ResponseWriter, includeCallStack bool) error {
func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, transID string, w http.ResponseWriter, includeCallStack bool) error {
// 1. find all hosting tool sfn
tagTools, err := register.ListToolCalls(s.Metadata)
if err != nil {
Expand Down Expand Up @@ -386,7 +387,7 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
toolCallsMap[index] = item
}
isFunctionCall = true
} else {
} else if streamRes.Choices[0].FinishReason != openai.FinishReasonToolCalls {
_, _ = io.WriteString(w, "data: ")
_ = json.NewEncoder(w).Encode(streamRes)
_, _ = io.WriteString(w, "\n")
Expand Down Expand Up @@ -436,7 +437,8 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
}
}
// 6. run llm function calls
llmCalls, err := s.runFunctionCalls(fnCalls, reqID)
reqID := id.New(16)
llmCalls, err := s.runFunctionCalls(fnCalls, transID, reqID)
if err != nil {
return err
}
Expand Down Expand Up @@ -471,9 +473,6 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
if err != nil {
return err
}
if len(streamRes.Choices) == 0 {
continue
}
_, _ = io.WriteString(w, "data: ")
_ = json.NewEncoder(w).Encode(streamRes)
_, _ = io.WriteString(w, "\n")
Expand All @@ -491,22 +490,23 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
}

// run llm-sfn function calls
func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, reqID string) ([]ai.ToolMessage, error) {
func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, transID, reqID string) ([]ai.ToolMessage, error) {
if len(fns) == 0 {
return nil, nil
}

asyncCall := &sfnAsyncCall{
val: make(map[string]ai.ToolMessage),
}

s.muCallCache.Lock()
s.sfnCallCache[reqID] = asyncCall
s.muCallCache.Unlock()

for tag, tcs := range fns {
ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "reqID", reqID)
ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "transID", transID, "reqID", reqID)
for _, fn := range tcs {
err := s.fireLlmSfn(tag, fn, reqID)
err := s.fireLlmSfn(tag, fn, transID, reqID)
if err != nil {
ylog.Error("send data to zipper", "err", err.Error())
continue
Expand All @@ -533,19 +533,22 @@ func (s *Service) runFunctionCalls(fns map[uint32][]*openai.ToolCall, reqID stri
}

// fireLlmSfn fires the llm-sfn function call by s.source.Write()
func (s *Service) fireLlmSfn(tag uint32, fn *openai.ToolCall, reqID string) error {
func (s *Service) fireLlmSfn(tag uint32, fn *openai.ToolCall, transID, reqID string) error {
ylog.Info(
"+invoke func",
"tag", tag,
"transID", transID,
"reqID", reqID,
"toolCallID", fn.ID,
"function", fn.Function.Name,
"arguments", fn.Function.Arguments,
"reqID", reqID)
)
data := &ai.FunctionCall{
TransID: transID,
ReqID: reqID,
ToolCallID: fn.ID,
Arguments: fn.Function.Arguments,
FunctionName: fn.Function.Name,
Arguments: fn.Function.Arguments,
}
buf, err := data.Bytes()
if err != nil {
Expand Down

0 comments on commit 5c9d6ee

Please sign in to comment.