Skip to content

Commit

Permalink
feat: add Google Gemini API provider (#2) (#742)
Browse files Browse the repository at this point in the history
## Related Docs: 

- https://ai.google.dev/docs/gemini_api_overview
- https://ai.google.dev/docs/function_calling

## How to use this provider

```yaml
bridge:
  ai:
    server:
      addr: localhost:8000
      provider: gemini

    providers:
      gemini:
        api_key: <your-api-key>
```

## Be careful that

1. The data format describes in api doc is different from actual api.
These can be found from the unit tests.
2. You can not set your function with `-` as this will break your api
response. Google will drop characters of the function name in its
response. So, I have to check this `-`, and replace to `_` if it
presence.
  • Loading branch information
fanweixiao committed Mar 6, 2024
1 parent b5317dc commit 2012a16
Show file tree
Hide file tree
Showing 8 changed files with 1,392 additions and 21 deletions.
41 changes: 21 additions & 20 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/yomorun/yomo/pkg/bridge/ai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)

Expand Down Expand Up @@ -112,26 +113,26 @@ var serveCmd = &cobra.Command{
},
}

func registerAIProvider(aiConfig *ai.Config) {
// register the AI provider
for name, provider := range aiConfig.Providers {
// register LLM provider
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
// TODO: register other providers
}
// register the OpenAI provider
if name == "openai" {
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
}
}
func registerAIProvider(aiConfig *ai.Config) error {
for name, provider := range aiConfig.Providers {
switch name {
case "azopenai":
ai.RegisterProvider(azopenai.NewProvider(
provider["api_key"],
provider["api_endpoint"],
provider["deployment_id"],
provider["api_version"],
))
case "gemini":
ai.RegisterProvider(gemini.NewProvider(provider["api_key"]))
case "openai":
ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
}
log.InfoStatusEvent(os.Stdout, "registered [%s] AI provider", name)
return nil
}

func init() {
Expand Down
3 changes: 2 additions & 1 deletion example/10-ai/sfn-timezone-calculator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Parameter struct {
}

func Description() string {
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is 'YYYY-MM-DD HH:MM:SS'. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format 'YYYY-MM-DD HH:MM:SS'. If you are not sure about the date value of `timeString`, set date value to '1900-01-01'"
return "Extract the source time and timezone information to `timeString` and `sourceTimezone`, extract the target timezone information to `targetTimezone`. the desired `timeString` format is `YYYY-MM-DD HH:MM:SS`. the `sourceTimezone` and `targetTimezone` are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format `YYYY-MM-DD HH:MM:SS`. If you are not sure about the date value of `timeString`, set date value to `1900-01-01`"
}

func InputSchema() any {
Expand Down Expand Up @@ -78,6 +78,7 @@ func handler(ctx serverless.Context) {
targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone)
if err != nil {
slog.Error("[sfn] ConvertTimezone error", "err", err)
fcCtx.WriteErrors(err)
return
}

Expand Down
131 changes: 131 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package gemini

import (
"encoding/json"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
)

func convertStandardToFunctionDeclaration(functionDefinition *ai.FunctionDefinition) *FunctionDeclaration {
if functionDefinition == nil {
return nil
}

return &FunctionDeclaration{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertStandardToFunctionParameters(functionDefinition.Parameters),
}
}

func convertFunctionDeclarationToStandard(functionDefinition *FunctionDeclaration) *ai.FunctionDefinition {
if functionDefinition == nil {
return nil
}

return &ai.FunctionDefinition{
Name: functionDefinition.Name,
Description: functionDefinition.Description,
Parameters: convertFunctionParametersToStandard(functionDefinition.Parameters),
}
}

func convertStandardToFunctionParameters(parameters *ai.FunctionParameters) *FunctionParameters {
if parameters == nil {
return nil
}

return &FunctionParameters{
Type: parameters.Type,
Properties: convertStandardToProperty(parameters.Properties),
Required: parameters.Required,
}
}

func convertFunctionParametersToStandard(parameters *FunctionParameters) *ai.FunctionParameters {
if parameters == nil {
return nil
}

return &ai.FunctionParameters{
Type: parameters.Type,
Properties: convertPropertyToStandard(parameters.Properties),
Required: parameters.Required,
}
}

func convertStandardToProperty(properties map[string]*ai.ParameterProperty) map[string]*Property {
if properties == nil {
return nil
}

result := make(map[string]*Property)
for k, v := range properties {
result[k] = &Property{
Type: v.Type,
Description: v.Description,
}
}
return result
}

func convertPropertyToStandard(properties map[string]*Property) map[string]*ai.ParameterProperty {
if properties == nil {
return nil
}

result := make(map[string]*ai.ParameterProperty)
for k, v := range properties {
result[k] = &ai.ParameterProperty{
Type: v.Type,
Description: v.Description,
}
}
return result
}

// generateJSONSchemaArguments generates the JSON schema arguments from OpenAPI compatible arguments
// https://ai.google.dev/docs/function_calling#how_it_works
func generateJSONSchemaArguments(args map[string]interface{}) string {
schema := make(map[string]interface{})

for k, v := range args {
schema[k] = v
}

schemaJSON, err := json.Marshal(schema)
if err != nil {
return ""
}

return string(schemaJSON)
}

func parseAPIResponseBody(respBody []byte) (*Response, error) {
var response *Response
err := json.Unmarshal(respBody, &response)
if err != nil {
ylog.Error("parseAPIResponseBody", "err", err, "respBody", string(respBody))
return nil, err
}
return response, nil
}

func parseToolCallFromResponse(response *Response) []ai.ToolCall {
calls := make([]ai.ToolCall, 0)
for _, candidate := range response.Candidates {
fn := candidate.Content.Parts[0].FunctionCall
fd := &ai.FunctionDefinition{
Name: fn.Name,
Arguments: generateJSONSchemaArguments(fn.Args),
}
call := ai.ToolCall{
ID: "cc-gemini-id",
Type: "cc-function",
Function: fd,
}
calls = append(calls, call)
}
return calls
}
43 changes: 43 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package gemini

// RequestBody is the request body
type RequestBody struct {
Contents Contents `json:"contents"`
Tools []Tool `json:"tools"`
}

// Contents is the contents in RequestBody
type Contents struct {
Role string `json:"role"`
Parts Parts `json:"parts"`
}

// Parts is the contents.parts in RequestBody
type Parts struct {
Text string `json:"text"`
}

// Tool is the element of tools in RequestBody
type Tool struct {
FunctionDeclarations []*FunctionDeclaration `json:"function_declarations"`
}

// FunctionDeclaration is the element of Tool
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters *FunctionParameters `json:"parameters"`
}

// FunctionParameters is the parameters of FunctionDeclaration
type FunctionParameters struct {
Type string `json:"type"`
Properties map[string]*Property `json:"properties"`
Required []string `json:"required"`
}

// Property is the element of ParameterProperties
type Property struct {
Type string `json:"type"`
Description string `json:"description"`
}
55 changes: 55 additions & 0 deletions pkg/bridge/ai/provider/gemini/model_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package gemini

type Response struct {
Candidates []Candidate `json:"candidates"`
PromptFeedback PromptFeedback `json:"promptFeedback"`
// UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Candidate is the element of Response
type Candidate struct {
Content *CandidateContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
// SafetyRatings []CandidateSafetyRating `json:"safetyRatings"`
}

// CandidateContent is the content of Candidate
type CandidateContent struct {
Parts []*Part `json:"parts"`
Role string `json:"role"`
}

// Part is the element of CandidateContent
type Part struct {
FunctionCall *FunctionCall `json:"functionCall"`
}

// FunctionCall is the functionCall of Part
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}

// CandidateSafetyRating is the safetyRatings of Candidate
type CandidateSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// UsageMetadata is the token usage in Response
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

// SafetyRating is the element of PromptFeedback
type SafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}

// PromptFeedback is the feedback of Prompt
type PromptFeedback struct {
SafetyRatings []*SafetyRating `json:"safetyRatings"`
}

0 comments on commit 2012a16

Please sign in to comment.