Skip to content

Commit

Permalink
✨ feat: support google tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed May 9, 2024
1 parent 737db0e commit 317e658
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 61 deletions.
3 changes: 3 additions & 0 deletions src/config/modelProviders/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const Google: ModelProviderCard = {
description: 'The best model for scaling across a wide range of tasks',
displayName: 'Gemini 1.0 Pro',
enabled: true,
functionCall: true,
id: 'gemini-pro',
maxOutput: 2048,
tokens: 30_720 + 2048,
Expand All @@ -47,6 +48,7 @@ const Google: ModelProviderCard = {
description:
'The best model for scaling across a wide range of tasks. This is a stable model that supports tuning.',
displayName: 'Gemini 1.0 Pro 001 (Tuning)',
functionCall: true,
id: 'gemini-1.0-pro-001',
maxOutput: 2048,
tokens: 30_720 + 2048,
Expand All @@ -71,6 +73,7 @@ const Google: ModelProviderCard = {
description: 'Mid-size multimodal model that supports up to 1 million tokens',
displayName: 'Gemini 1.5 Pro',
enabled: true,
functionCall: true,
id: 'gemini-1.5-pro-latest',
maxOutput: 8192,
tokens: 1_048_576 + 8192,
Expand Down
109 changes: 99 additions & 10 deletions src/libs/agent-runtime/google/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
import { Content, GoogleGenerativeAI, Part } from '@google/generative-ai';
import { GoogleGenerativeAIStream, StreamingTextResponse } from 'ai';
import {
Content,
FunctionDeclaration,
FunctionDeclarationSchemaProperty,
FunctionDeclarationSchemaType,
Tool as GoogleFunctionCallTool,
GoogleGenerativeAI,
Part,
} from '@google/generative-ai';
import { JSONSchema7 } from 'json-schema';
import { transform } from 'lodash-es';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
import {
ChatCompetitionOptions,
ChatCompletionTool,
ChatStreamPayload,
OpenAIChatMessage,
UserMessageContentPart,
} from '../types';
import { ModelProvider } from '../types/type';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
import { GoogleGenerativeAIStream, googleGenAIResultToStream } from '../utils/streams';
import { parseDataUri } from '../utils/uriParser';

enum HarmCategory {
Expand Down Expand Up @@ -42,7 +54,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {

const contents = this.buildGoogleMessages(payload.messages, model);

const geminiStream = await this.client
const geminiStreamResult = await this.client
.getGenerativeModel(
{
generationConfig: {
Expand Down Expand Up @@ -74,19 +86,20 @@ export class LobeGoogleAI implements LobeRuntimeAI {
},
{ apiVersion: 'v1beta', baseUrl: this.baseURL },
)
.generateContentStream({ contents });

// Convert the response into a friendly text-stream
const stream = GoogleGenerativeAIStream(geminiStream, options?.callback);
.generateContentStream({ contents, tools: this.buildGoogleTools(payload.tools) });

const [debug, output] = stream.tee();
const googleStream = googleGenAIResultToStream(geminiStreamResult);
const [prod, useForDebug] = googleStream.tee();

if (process.env.DEBUG_GOOGLE_CHAT_COMPLETION === '1') {
debugStream(debug).catch(console.error);
debugStream(useForDebug).catch();
}

// Convert the response into a friendly text-stream
const stream = GoogleGenerativeAIStream(prod, options?.callback);

// Respond with the stream
return new StreamingTextResponse(output, { headers: options?.headers });
return StreamingResponse(stream, { headers: options?.headers });
} catch (e) {
const err = e as Error;

Expand Down Expand Up @@ -226,6 +239,82 @@ export class LobeGoogleAI implements LobeRuntimeAI {
return defaultError;
}
}

/**
 * Convert OpenAI-style tool definitions into Google's Tool payload.
 *
 * Google expects a single Tool entry whose `functionDeclarations` array
 * carries every declared function, so all tools are folded into one entry.
 *
 * @param tools - OpenAI-format tool list from the chat payload; may be absent.
 * @returns the Google tool list, or `undefined` when there is nothing to send
 *          (Google rejects an empty tools array, so we omit the field entirely).
 */
private buildGoogleTools(
  tools: ChatCompletionTool[] | undefined,
): GoogleFunctionCallTool[] | undefined {
  if (!tools || tools.length === 0) return;

  return [
    {
      functionDeclarations: tools.map((tool) => this.convertToolToGoogleTool(tool)),
    },
  ];
}

/**
 * Map a single OpenAI-style tool onto Google's FunctionDeclaration shape.
 *
 * The top-level JSON-Schema `parameters` object becomes a Google
 * FunctionDeclarationSchema of type OBJECT; each property is converted
 * recursively via `convertSchemaObject`.
 *
 * @param tool - one OpenAI-format tool (function name, description, JSON Schema params).
 * @returns the equivalent Google FunctionDeclaration.
 */
private convertToolToGoogleTool = (tool: ChatCompletionTool): FunctionDeclaration => {
  const functionDeclaration = tool.function;
  const parameters = functionDeclaration.parameters;

  return {
    description: functionDeclaration.description,
    name: functionDeclaration.name,
    parameters: {
      description: parameters?.description,
      // transform builds a fresh object keyed like `properties`, with each
      // JSON-Schema value rewritten into Google's schema variant
      properties: transform(parameters?.properties, (result, value, key: string) => {
        result[key] = this.convertSchemaObject(value as JSONSchema7);
      }),
      required: parameters?.required,
      type: FunctionDeclarationSchemaType.OBJECT,
    },
  };
};

/**
 * Recursively rewrite a JSON Schema node into Google's
 * FunctionDeclarationSchemaProperty equivalent.
 *
 * Only the `type` discriminator (and nested `properties`/`items`) is
 * rewritten; all other schema fields are passed through via spread.
 * Casts to `any` are needed because Google's schema type is narrower
 * than what a spread JSONSchema7 produces.
 *
 * @param schema - a JSON Schema (draft-07) node.
 * @returns the Google-flavoured schema node.
 */
private convertSchemaObject(schema: JSONSchema7): FunctionDeclarationSchemaProperty {
  switch (schema.type) {
    // NOTE: `default` deliberately falls through to the object case so that
    // untyped/unknown schema nodes are treated as objects.
    default:
    case 'object': {
      return {
        ...schema,
        properties: Object.fromEntries(
          Object.entries(schema.properties || {}).map(([key, value]) => [
            key,
            this.convertSchemaObject(value as JSONSchema7),
          ]),
        ),
        type: FunctionDeclarationSchemaType.OBJECT,
      } as any;
    }

    case 'array': {
      return {
        ...schema,
        items: this.convertSchemaObject(schema.items as JSONSchema7),
        type: FunctionDeclarationSchemaType.ARRAY,
      } as any;
    }

    case 'string': {
      return { ...schema, type: FunctionDeclarationSchemaType.STRING } as any;
    }

    case 'number': {
      return { ...schema, type: FunctionDeclarationSchemaType.NUMBER } as any;
    }

    case 'boolean': {
      return { ...schema, type: FunctionDeclarationSchemaType.BOOLEAN } as any;
    }
  }
}
}

export default LobeGoogleAI;
Expand Down
48 changes: 39 additions & 9 deletions src/libs/agent-runtime/utils/debugStream.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,48 @@
// no need to introduce a package to get the current time as this module is just a debug utility
const getTime = () => {
  // Format: YYYY-M-D H:M:S.ms (unpadded — good enough for debug logs).
  // Original used getDate() for the month slot and getDay() (day of WEEK, 0-6)
  // for the day slot, producing timestamps like "2024-9-4" on May 9th.
  const date = new Date();
  return `${date.getFullYear()}-${date.getMonth() + 1}-${date.getDate()} ${date.getHours()}:${date.getMinutes()}:${date.getSeconds()}.${date.getMilliseconds()}`;
};

export const debugStream = async (stream: ReadableStream) => {
let done = false;
let finished = false;
let chunk = 0;
let chunkValue: any;
const decoder = new TextDecoder();

const reader = stream.getReader();
while (!done) {
const { value, done: _done } = await reader.read();
const chunkValue = decoder.decode(value, { stream: true });
if (!_done) {
console.log(`[chunk ${chunk}]`);

console.log(`[stream start] ${getTime()}`);

while (!finished) {
try {
const { value, done } = await reader.read();

if (done) {
console.log(`[stream finished] total chunks: ${chunk}\n`);
finished = true;
break;
}

chunkValue = value;

// if the value is ArrayBuffer, we need to decode it
if ('byteLength' in value) {
chunkValue = decoder.decode(value, { stream: true });
} else if (typeof value !== 'string') {
chunkValue = JSON.stringify(value);
}

console.log(`[chunk ${chunk}] ${getTime()}`);
console.log(chunkValue);
}
console.log(`\n`);

done = _done;
chunk++;
finished = done;
chunk++;
} catch (e) {
finished = true;
console.error('[debugStream error]', e);
console.error('[error chunk value:]', chunkValue);
}
}
};
110 changes: 110 additions & 0 deletions src/libs/agent-runtime/utils/streams/google-ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import {
EnhancedGenerateContentResponse,
GenerateContentStreamResult,
} from '@google/generative-ai';
import { createCallbacksTransformer, readableFromAsyncIterable } from 'ai';

import { nanoid } from '@/utils/uuid';

import { ChatStreamCallbacks } from '../../types';
import {
StreamProtocolChunk,
StreamToolCallChunk,
chatStreamable,
generateToolCallId,
} from './protocol';

/**
 * Convert one Google GenAI response chunk into the internal stream protocol.
 *
 * Function calls take priority: when the chunk carries function calls they are
 * emitted as a `tool_calls` chunk in OpenAI-compatible shape; otherwise the
 * chunk's text is emitted as a `text` chunk. Each protocol chunk gets a fresh
 * nanoid since Google chunks carry no id of their own.
 *
 * TODO: may need another structure to add support for multiple choices.
 *
 * @param chunk - one streamed response item from the Google SDK.
 * @returns a protocol chunk of type `tool_calls` or `text`.
 */
const transformGoogleGenerativeAIStream = (
  chunk: EnhancedGenerateContentResponse,
): StreamProtocolChunk => {
  const functionCalls = chunk.functionCalls();

  if (functionCalls) {
    return {
      data: functionCalls.map(
        (value, index): StreamToolCallChunk => ({
          function: {
            // Google returns args as an object; the protocol expects a JSON string
            arguments: JSON.stringify(value.args),
            name: value.name,
          },
          id: generateToolCallId(index, value.name),
          index,
          type: 'function',
        }),
      ),
      id: nanoid(),
      type: 'tool_calls',
    };
  }

  return {
    data: chunk.text(),
    id: nanoid(),
    type: 'text',
  };
};

// Convert the Google SDK's stream result into a ReadableStream of response
// chunks. NOTE(review): despite earlier comments, this is used for the
// production path too — the caller tees the result for prod and debug.
export const googleGenAIResultToStream = (stream: GenerateContentStreamResult) => {
  // make the response to the streamable format
  return readableFromAsyncIterable(chatStreamable(stream.stream));
};

/**
 * Re-encode a stream of Google response chunks as SSE-style protocol frames
 * (`id:` / `event:` / `data:` lines), then run the caller's stream callbacks.
 *
 * @param rawStream - ReadableStream of Google response chunks.
 * @param callbacks - optional lifecycle callbacks applied after re-encoding.
 */
export const GoogleGenerativeAIStream = (
  rawStream: ReadableStream<EnhancedGenerateContentResponse>,
  callbacks?: ChatStreamCallbacks,
) => {
  const protocolTransformer = new TransformStream({
    transform(chunk, controller) {
      const parsed = transformGoogleGenerativeAIStream(chunk);

      controller.enqueue(`id: ${parsed.id}\n`);
      controller.enqueue(`event: ${parsed.type}\n`);
      controller.enqueue(`data: ${JSON.stringify(parsed.data)}\n\n`);
    },
  });

  return rawStream
    .pipeThrough(protocolTransformer)
    .pipeThrough(createCallbacksTransformer(callbacks));
};
1 change: 1 addition & 0 deletions src/libs/agent-runtime/utils/streams/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from './google-ai';
export * from './openai';
export * from './protocol';
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/utils/streams/minimax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { createCallbacksTransformer } from 'ai';
import OpenAI from 'openai';

import { ChatStreamCallbacks } from '../../types';
import { transformOpenAIStream } from './protocol';
import { transformOpenAIStream } from './openai';

const unit8ArrayToJSONChunk = (unit8Array: Uint8Array): OpenAI.ChatCompletionChunk => {
const decoder = new TextDecoder();
Expand Down

0 comments on commit 317e658

Please sign in to comment.