Skip to content

Commit

Permalink
refactor: split provider and register (#743)
Browse files Browse the repository at this point in the history
# Description

1. splitting the original `LLMProvider` into two interfaces, one for
storing registered funcs and another for LLM requesting.


```go
// The storage interface
type Register interface {
	ListToolCalls(md metadata.M) (map[uint32]ai.ToolCall, error)
	RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64, md metadata.M) error
	UnregisterFunction(name string, connID uint64)
}
```
Additionally, provide the `register.SetRegister()` method to allow
modification of the default Register.

```go
// The request interface
type LLMProvider interface {
	// Name returns the name of the llm provider
	Name() string
	GetChatCompletions(prompt string, md metaddata.M) (*ai.InvokeResponse, error)
}
```

2. Add the `md` attribute to the `service` and inject the `md` using the
`ExchangeMetadataFunc` method. The signature of the
`ExchangeMetadataFunc` function is as follows:
```go
// ExchangeMetadataFunc is used to exchange metadata
type ExchangeMetadataFunc func(credential string) (metadata.M, error)

// DefaultExchangeMetadataFunc is the default ExchangeMetadataFunc, It returns an empty metadata.
func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) {
	return metadata.M{}, nil
}
```

3. Declare the `HandleOverview` and `HandleInvoke` methods and mount the
service to `req.Context()`.

4. fix `ConnMiddleware` function bug in the `ai.go` file.
  • Loading branch information
woorui committed Mar 11, 2024
1 parent 2012a16 commit 9177893
Show file tree
Hide file tree
Showing 17 changed files with 459 additions and 1,171 deletions.
38 changes: 19 additions & 19 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,25 @@ var serveCmd = &cobra.Command{
}

func registerAIProvider(aiConfig *ai.Config) error {
for name, provider := range aiConfig.Providers {
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
case "gemini":
ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))
case "openai":
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
return nil
for name, provider := range aiConfig.Providers {
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
case "gemini":
ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))
case "openai":
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
return nil
}

func init() {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/tetratelabs/wazero v1.6.0
github.com/vmihailenco/msgpack/v5 v5.4.1
github.com/yomorun/y3 v1.0.5
Expand Down Expand Up @@ -64,7 +64,7 @@ require (
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/stretchr/objx v0.5.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/teivah/onecontext v1.3.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,18 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.1 h1:4VhoImhV/Bm0ToFkXFi8hXNXwpDRZ/ynw3amt82mzq0=
github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdrifcy0=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
Expand Down
90 changes: 30 additions & 60 deletions pkg/bridge/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,89 +11,59 @@ import (
"github.com/yomorun/yomo/core"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/ylog"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
"gopkg.in/yaml.v3"
)

var (
// ErrNotExistsProvider is the error when the provider does not exist
ErrNotExistsProvider = errors.New("llm provider does not exist")
// ErrNotImplementedService is the error when the service is not implemented
ErrNotImplementedService = errors.New("llm service is not implemented")
// ErrConfigNotFound is the error when the ai config was not found
ErrConfigNotFound = errors.New("ai config was not found")
// ErrConfigFormatError is the error when the ai config format is incorrect
ErrConfigFormatError = errors.New("ai config format is incorrect")
)

// ======================= Package Functions =======================

// RegisterFunction registers the tool function
func RegisterFunction(tag uint32, functionDefinition []byte, connID uint64) error {
provider, err := GetDefaultProvider()
if err != nil {
return err
}
fd := ai.FunctionDefinition{}
err = json.Unmarshal(functionDefinition, &fd)
if err != nil {
ylog.Error("unmarshal function definition", "error", err)
return err
}
return provider.RegisterFunction(tag, &fd, connID)
}

// UnregisterFunction unregister the tool function
func UnregisterFunction(name string, connID uint64) error {
provider, err := GetDefaultProvider()
if err != nil {
return err
}
return provider.UnregisterFunction(name, connID)
}

// ListToolCalls lists the AI tool calls
func ListToolCalls() (map[uint32]ai.ToolCall, error) {
provider, err := GetDefaultProvider()
if err != nil {
return nil, err
}
return provider.ListToolCalls()
}

// ConnMiddleware returns a ConnMiddleware that can be used to intercept the connection.
func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
return func(conn *core.Connection) {
connMd := conn.Metadata().Clone()
defer func() {
next(conn)
register.UnregisterFunction(conn.ID(), connMd)
conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
}()

// check sfn type and is ai function
if conn.ClientType() != core.ClientTypeStreamFunction {
next(conn)
return
}
for {
f, err := conn.FrameConn().ReadFrame()
// unregister ai function on any error
f, err := conn.FrameConn().ReadFrame()
// unregister ai function on any error
if err != nil {
conn.Logger.Error("failed to read frame on ai middleware", "err", err, "type", fmt.Sprintf("%T", err))
conn.Logger.Info("error type", "type", fmt.Sprintf("%T", err))
return
}
if ff, ok := f.(*frame.AIRegisterFunctionFrame); ok {
err := conn.FrameConn().WriteFrame(&frame.AIRegisterFunctionAckFrame{Name: ff.Name, Tag: ff.Tag})
if err != nil {
conn.Logger.Error("failed to read frame on ai middleware", "err", err)
conn.Logger.Info("error type", "type", fmt.Sprintf("%T", err))
name := conn.Name()
conn.Logger.Info("unregister ai function", "name", name, "connID", conn.ID())
UnregisterFunction(name, conn.ID())
conn.Logger.Error("failed to write ai RegisterFunctionAckFrame", "name", ff.Name, "tag", ff.Tag, "err", err)
return
}
if ff, ok := f.(*frame.AIRegisterFunctionFrame); ok {
err := conn.FrameConn().WriteFrame(&frame.AIRegisterFunctionAckFrame{Name: ff.Name, Tag: ff.Tag})
if err != nil {
conn.Logger.Error("failed to write ai RegisterFunctionAckFrame", "name", ff.Name, "tag", ff.Tag, "err", err)
return
}
// register ai function
err = RegisterFunction(ff.Tag, ff.Definition, conn.ID())
if err != nil {
conn.Logger.Error("failed to register ai function", "name", ff.Name, "tag", ff.Tag, "err", err)
return
}
conn.Logger.Info("register ai function success", "name", ff.Name, "tag", ff.Tag, "definition", string(ff.Definition))
// register ai function
fd := ai.FunctionDefinition{}
err = json.Unmarshal(ff.Definition, &fd)
if err != nil {
conn.Logger.Error("unmarshal function definition", "error", err)
return
}
next(conn)
err = register.RegisterFunction(ff.Tag, &fd, conn.ID(), connMd)
if err != nil {
conn.Logger.Error("failed to register ai function", "name", ff.Name, "tag", ff.Tag, "err", err)
return
}
conn.Logger.Info("register ai function success", "name", ff.Name, "tag", ff.Tag, "definition", string(ff.Definition))
}
}
}
Expand Down

0 comments on commit 9177893

Please sign in to comment.