Skip to content

Commit

Permalink
✨ feat: support minimax tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed May 8, 2024
1 parent c9d7fdf commit f0a0764
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 52 deletions.
3 changes: 3 additions & 0 deletions src/config/modelProviders/minimax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,23 @@ const Minimax: ModelProviderCard = {
description: '复杂场景,例如应用题计算、科学计算等场景',
displayName: 'abab6.5',
enabled: true,
functionCall: true,
id: 'abab6.5-chat',
tokens: 8192,
},
{
description: '通用场景',
displayName: 'abab6.5s',
enabled: true,
functionCall: true,
id: 'abab6.5s-chat',
tokens: 245_760,
},
{
description: '更复杂的格式化文本生成',
displayName: 'abab6',
enabled: true,
functionCall: true,
id: 'abab6-chat',
tokens: 32_768,
},
Expand Down
2 changes: 1 addition & 1 deletion src/config/server/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ export const getProviderConfig = () => {
AWS_ACCESS_KEY_ID: AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY: process.env.AWS_SECRET_ACCESS_KEY || '',

ENABLE_OLLAMA: process.env.ENABLE_OLLAMA as unknown as boolean,
ENABLE_OLLAMA: Boolean(process.env.ENABLE_OLLAMA),
OLLAMA_PROXY_URL: process.env.OLLAMA_PROXY_URL || '',
OLLAMA_MODEL_LIST: process.env.OLLAMA_MODEL_LIST || process.env.OLLAMA_CUSTOM_MODELS,
};
Expand Down
44 changes: 13 additions & 31 deletions src/libs/agent-runtime/minimax/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { StreamingTextResponse } from 'ai';
import { isEmpty } from 'lodash-es';
import OpenAI from 'openai';

import { debugStream } from '@/libs/agent-runtime/utils/debugStream';
import { MinimaxStream } from '@/libs/agent-runtime/utils/streams/minimax';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
Expand All @@ -13,6 +13,8 @@ import {
ModelProvider,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';

interface MinimaxBaseResponse {
base_resp?: {
Expand Down Expand Up @@ -74,13 +76,6 @@ export class LobeMinimaxAI implements LobeRuntimeAI {
options?: ChatCompetitionOptions,
): Promise<StreamingTextResponse> {
try {
let streamController: ReadableStreamDefaultController | undefined;
const readableStream = new ReadableStream({
start(controller) {
streamController = controller;
},
});

const response = await fetch('https://api.minimax.chat/v1/text/chatcompletion_v2', {
body: JSON.stringify(this.buildCompletionsParams(payload)),
headers: {
Expand All @@ -107,12 +102,10 @@ export class LobeMinimaxAI implements LobeRuntimeAI {
debugStream(debug).catch(console.error);
}

this.parseResponse(prod.getReader(), streamController);

// wait for the first response, and throw an error if minimax returns an error
await this.parseFirstResponse(prod2.getReader());

return new StreamingTextResponse(readableStream, { headers: options?.headers });
return StreamingResponse(MinimaxStream(prod), { headers: options?.headers });
} catch (error) {
console.log('error', error);
const err = error as Error | ChatCompletionErrorPayload;
Expand Down Expand Up @@ -154,30 +147,19 @@ export class LobeMinimaxAI implements LobeRuntimeAI {
max_tokens: this.getMaxTokens(payload.model),
stream: true,
temperature: temperature === 0 ? undefined : temperature,

tools: params.tools?.map((tool) => ({
function: {
description: tool.function.description,
name: tool.function.name,
parameters: JSON.stringify(tool.function.parameters),
},
type: 'function',
})),
top_p: top_p === 0 ? undefined : top_p,
};
}

private async parseResponse(
reader: ReadableStreamDefaultReader<Uint8Array>,
streamController: ReadableStreamDefaultController | undefined,
) {
const encoder = new TextEncoder();
const decoder = new TextDecoder();
let done = false;

while (!done) {
const { value, done: doneReading } = await reader.read();
done = doneReading;
const chunkValue = decoder.decode(value, { stream: true });
const data = parseMinimaxResponse(chunkValue);
const text = data?.choices?.at(0)?.delta?.content || undefined;
streamController?.enqueue(encoder.encode(text));
}

streamController?.close();
}

private async parseFirstResponse(reader: ReadableStreamDefaultReader<Uint8Array>) {
const decoder = new TextDecoder();

Expand Down
39 changes: 39 additions & 0 deletions src/libs/agent-runtime/utils/streams/minimax.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { createCallbacksTransformer } from 'ai';
import OpenAI from 'openai';

import { ChatStreamCallbacks } from '../../types';
import { transformOpenAIStream } from './protocol';

// Decode a raw SSE chunk from the Minimax API into an OpenAI-compatible
// completion chunk object.
//
// chunkValue example:
// data: {"id":"028a65377137d57aaceeffddf48ae99f","choices":[{"finish_reason":"tool_calls","index":0,"delta":{"role":"assistant","tool_calls":[{"id":"call_function_7371372822","type":"function","function":{"name":"realtime-weather____fetchCurrentWeather","arguments":"{\"city\": [\"杭州\", \"北京\"]}"}}]}}],"created":155511,"model":"abab6.5s-chat","object":"chat.completion.chunk"}
//
// NOTE(review): assumes each Uint8Array carries exactly one complete JSON
// payload — a buffer containing several `data:` events, or a JSON document
// split across buffers, will throw from JSON.parse. Confirm the upstream
// fetch always delivers whole events per chunk.
const unit8ArrayToJSONChunk = (unit8Array: Uint8Array): OpenAI.ChatCompletionChunk => {
  // One-shot decode: do NOT pass `{ stream: true }` here — a throwaway
  // decoder in streaming mode silently buffers (i.e. drops) a trailing
  // incomplete UTF-8 sequence instead of flushing it.
  let chunkValue = new TextDecoder().decode(unit8Array).trim();

  // Remove the SSE `data:` prefix so the remainder parses as plain JSON.
  if (chunkValue.startsWith('data:')) {
    chunkValue = chunkValue.slice(5).trim();
  }

  return JSON.parse(chunkValue);
};

/**
 * Adapt a raw Minimax SSE byte stream into the app's unified stream protocol
 * (`id:` / `event:` / `data:` records), then pipe it through the `ai`
 * package's callback transformer so the provided ChatStreamCallbacks fire as
 * chunks flow by.
 */
export const MinimaxStream = (stream: ReadableStream, callbacks?: ChatStreamCallbacks) => {
  const toProtocolEvents = new TransformStream({
    transform(buffer, controller) {
      // Decode the raw bytes into an OpenAI-compatible chunk, then map it
      // onto the shared protocol shape.
      const completionChunk = unit8ArrayToJSONChunk(buffer);
      const { type, id, data } = transformOpenAIStream(completionChunk);

      // Emit one SSE-style record per incoming chunk.
      controller.enqueue(`id: ${id}\n`);
      controller.enqueue(`event: ${type}\n`);
      controller.enqueue(`data: ${JSON.stringify(data)}\n\n`);
    },
  });

  return stream.pipeThrough(toProtocolEvents).pipeThrough(createCallbacksTransformer(callbacks));
};
11 changes: 9 additions & 2 deletions src/libs/agent-runtime/utils/streams/openai.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import { createCallbacksTransformer } from 'ai';
import { createCallbacksTransformer, readableFromAsyncIterable } from 'ai';
import OpenAI from 'openai';
import type { Stream } from 'openai/streaming';

import { ChatStreamCallbacks } from '../../types';
import { transformOpenAIStream } from './protocol';

// Re-yield every chunk of the OpenAI SDK's async iterable unchanged; exists
// so the iterable can be wrapped with `readableFromAsyncIterable` below.
const chatStreamable = async function* (stream: AsyncIterable<OpenAI.ChatCompletionChunk>) {
  yield* stream;
};

export const OpenAIStream = (
stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => {
const readableStream = stream instanceof ReadableStream ? stream : stream.toReadableStream();
const readableStream =
stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));

return readableStream
.pipeThrough(
Expand Down
31 changes: 13 additions & 18 deletions src/libs/agent-runtime/utils/streams/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,31 @@ interface StreamProtocolChunk {
type: 'text' | 'tool_calls' | 'data' | 'stop';
}

export const transformOpenAIStream = (chunk: Uint8Array): StreamProtocolChunk => {
const decoder = new TextDecoder();

const chunkValue = decoder.decode(chunk, { stream: true });
const jsonValue: OpenAI.ChatCompletionChunk = JSON.parse(chunkValue);

export const transformOpenAIStream = (chunk: OpenAI.ChatCompletionChunk): StreamProtocolChunk => {
// maybe need another structure to add support for multiple choices
const item = jsonValue.choices[0];

if (typeof item.delta.content === 'string') {
return { data: item.delta.content, id: jsonValue.id, type: 'text' };
}
const item = chunk.choices[0];

if (item.delta.tool_calls) {
return { data: item.delta.tool_calls, id: jsonValue.id, type: 'tool_calls' };
if (typeof item.delta?.content === 'string') {
return { data: item.delta.content, id: chunk.id, type: 'text' };
}

if (item.delta.content === null) {
return { data: item.delta, id: jsonValue.id, type: 'data' };
if (item.delta?.tool_calls) {
return { data: item.delta.tool_calls, id: chunk.id, type: 'tool_calls' };
}

// 给定结束原因
if (item.finish_reason) {
return { data: item.finish_reason, id: jsonValue.id, type: 'stop' };
return { data: item.finish_reason, id: chunk.id, type: 'stop' };
}

if (item.delta.content === null) {
return { data: item.delta, id: chunk.id, type: 'data' };
}

// 其余情况下,返回 delta 和 index
return {
data: { delta: item.delta, id: jsonValue.id, index: item.index },
id: jsonValue.id,
data: { delta: item.delta, id: chunk.id, index: item.index },
id: chunk.id,
type: 'data',
};
};

0 comments on commit f0a0764

Please sign in to comment.