Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAP Generative AI Hub #737

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 14 additions & 2 deletions core/config/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ declare global {
apiType?: string;
region?: string;
projectId?: string;
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;

_fetch?: (input: any, init?: any) => Promise<any>;

Expand Down Expand Up @@ -198,6 +202,12 @@ declare global {
// GCP Options
region?: string;
projectId?: string;

// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;
}
type RequireAtLeastOne<T, Keys extends keyof T = keyof T> = Pick<
T,
Expand Down Expand Up @@ -231,6 +241,7 @@ declare global {

// IDE


export interface DiffLine {
type: "new" | "old" | "same";
line: string;
Expand Down Expand Up @@ -342,6 +353,7 @@ declare global {
| "gemini"
| "mistral"
| "bedrock"
| "sap-gen-ai-hub"
| "deepinfra";

export type ModelName =
Expand Down Expand Up @@ -509,8 +521,8 @@ declare global {
disableIndexing?: boolean;
userToken?: string;
}


}

export {};
Expand Down
12 changes: 12 additions & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ export interface ILLM extends LLMOptions {
apiType?: string;
region?: string;
projectId?: string;
// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;

_fetch?: (input: any, init?: any) => Promise<any>;

Expand Down Expand Up @@ -196,6 +201,12 @@ export interface LLMOptions {
// GCP Options
region?: string;
projectId?: string;

// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;
}
type RequireAtLeastOne<T, Keys extends keyof T = keyof T> = Pick<
T,
Expand Down Expand Up @@ -340,6 +351,7 @@ type ModelProvider =
| "gemini"
| "mistral"
| "bedrock"
| "sap-gen-ai-hub"
| "deepinfra";

export type ModelName =
Expand Down
15 changes: 14 additions & 1 deletion core/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ function autodetectTemplateType(model: string): TemplateType | undefined {
if (lower.includes("mistral")) {
return "llama2";
}

if (lower.includes("deepseek")) {
return "deepseek";
}
Expand Down Expand Up @@ -202,6 +201,12 @@ export abstract class BaseLLM implements ILLM {
region?: string;
projectId?: string;

// SAP Gen AI Core options
resourceGroup?: string;
authURL?: string;
clientID?: string;
clientSecret?: string;

private _llmOptions: LLMOptions;

constructor(options: LLMOptions) {
Expand Down Expand Up @@ -252,6 +257,14 @@ export abstract class BaseLLM implements ILLM {
this.apiType = options.apiType;
this.region = options.region;
this.projectId = options.projectId;

// SAP Gen AI Core options
this.resourceGroup = options.resourceGroup;
this.authURL = options.authURL;
this.clientID = options.clientID;
this.clientSecret = options.clientSecret;


}

private _compileChatMessages(
Expand Down
33 changes: 19 additions & 14 deletions core/llm/llms/OpenAI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OpenAI extends BaseLLM {

return completion;
}
private _getCompletionUrl() {
protected _getCompletionUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/openai/deployments/${this.engine}/completions?api-version=${this.apiVersion}`;
} else {
Expand All @@ -79,6 +79,15 @@ class OpenAI extends BaseLLM {
}
}

protected async _getRequestHeaders(): Promise<Record<string, string>> {
return {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
};
}


protected async *_streamComplete(
prompt: string,
options: CompletionOptions
Expand All @@ -95,13 +104,11 @@ class OpenAI extends BaseLLM {
prompt: string,
options: CompletionOptions
): AsyncGenerator<string> {
const response = await this.fetch(this._getCompletionUrl(), {
const header = await this._getRequestHeaders();
const url = this._getCompletionUrl();
const response = await this.fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
},
headers: header,
body: JSON.stringify({
...{
prompt,
Expand All @@ -124,7 +131,7 @@ class OpenAI extends BaseLLM {
}
}

private _getChatUrl() {
protected _getChatUrl() {
if (this.apiType === "azure") {
return `${this.apiBase}/openai/deployments/${this.engine}/chat/completions?api-version=${this.apiVersion}`;
} else {
Expand Down Expand Up @@ -163,13 +170,11 @@ class OpenAI extends BaseLLM {
return;
}

const response = await this.fetch(this._getChatUrl(), {
const header = await this._getRequestHeaders();
const url = this._getChatUrl();
const response = await this.fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Authorization: `Bearer ${this.apiKey}`,
"api-key": this.apiKey || "", // For Azure
},
headers: header,
body: JSON.stringify({
...this._convertArgs(options, messages),
stream: true,
Expand Down
114 changes: 114 additions & 0 deletions core/llm/llms/SAPGenAIHub.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import { ModelProvider } from "../..";
import OpenAI from "./OpenAI";

/**
 * Provider for SAP Generative AI Hub (SAP AI Core), which exposes
 * OpenAI-compatible chat/completions endpoints guarded by an OAuth2
 * client-credentials token and an "AI-Resource-Group" header.
 *
 * Required options: apiBase, authURL, clientID, clientSecret.
 * Optional: resourceGroup (defaults to "default").
 */
class SAPGenAIHub extends OpenAI {
  // Cached bearer token with its absolute expiry time (ms since epoch).
  private tokenCache: { token: string; expiry: number } | null = null;

  // Refresh this many ms before the nominal expiry so we never send a
  // request with a token that dies in flight.
  private static readonly TOKEN_EXPIRY_MARGIN_MS = 60 * 1000;

  static providerName: ModelProvider = "sap-gen-ai-hub";

  /**
   * Validate the OAuth client-credentials settings and normalize the auth
   * URL so it always points at the /oauth/token endpoint.
   * @throws if authURL, clientID, or clientSecret is missing.
   */
  private _getTokenParams(): { authURL: string; clientID: string; clientSecret: string } {
    if (!this.authURL || !this.clientID || !this.clientSecret) {
      throw new Error("Authentication parameters (authURL, clientID, clientSecret) are undefined");
    }
    return {
      // Accept either the bare auth host or a full token-endpoint URL.
      authURL: this.authURL.endsWith("/oauth/token") ? this.authURL : `${this.authURL}/oauth/token`,
      clientID: this.clientID,
      clientSecret: this.clientSecret,
    };
  }

  /**
   * fetch() with a hard timeout. Uses AbortController so the underlying
   * request is actually cancelled on timeout instead of leaking the socket
   * (a bare Promise race would leave the request running).
   */
  private async fetchWithTimeout(url: string, options: RequestInit, timeout: number): Promise<Response> {
    const controller = new AbortController();
    const timer = setTimeout(() => controller.abort(), timeout);
    try {
      return await fetch(url, { ...options, signal: controller.signal });
    } catch (err) {
      // Translate the abort into the timeout error callers expect.
      if (controller.signal.aborted) {
        throw new Error("Request timed out");
      }
      throw err;
    } finally {
      clearTimeout(timer);
    }
  }

  /**
   * Perform the OAuth2 client-credentials grant against the token endpoint.
   * @returns the access token and its lifetime in seconds.
   * @throws on a non-2xx response or a response missing access_token.
   */
  private async fetchOAuthToken(): Promise<{ token: string; expiresIn: number }> {
    const params = this._getTokenParams();
    const response = await this.fetchWithTimeout(params.authURL, {
      method: "POST",
      headers: { "Content-Type": "application/x-www-form-urlencoded" },
      body: new URLSearchParams({
        client_id: params.clientID,
        client_secret: params.clientSecret,
        grant_type: "client_credentials",
      }).toString(),
    }, 10000); // 10-second timeout

    if (!response.ok) {
      throw new Error(
        `OAuth token request failed: ${response.status} ${response.statusText}`
      );
    }
    const data = (await response.json()) as { access_token?: string; expires_in?: number };
    if (!data.access_token) {
      throw new Error("OAuth token response did not contain an access_token");
    }
    return {
      token: data.access_token,
      // RFC 6749 makes expires_in optional; fall back to one hour.
      expiresIn: typeof data.expires_in === "number" ? data.expires_in : 3600,
    };
  }

  /** Fetch a new token unless the cached one is still comfortably valid. */
  private async ensureToken(): Promise<void> {
    const now = Date.now();
    if (!this.tokenCache || now >= this.tokenCache.expiry - SAPGenAIHub.TOKEN_EXPIRY_MARGIN_MS) {
      const { token, expiresIn } = await this.fetchOAuthToken();
      this.tokenCache = { token, expiry: now + expiresIn * 1000 };
    }
  }

  /** Headers for chat/completion requests, including the resource group. */
  protected async _getRequestHeaders(): Promise<Record<string, string>> {
    await this.ensureToken();
    return {
      "Content-Type": "application/json",
      Authorization: `Bearer ${this.tokenCache?.token}`,
      "api-key": this.apiKey || "",
      "AI-Resource-Group": this.resourceGroup || "default",
    };
  }

  /**
   * Build an endpoint URL for the given path ("chat/completions" or
   * "completions"), handling both Azure-style and OpenAI-style bases.
   * @throws if apiBase is not configured (non-Azure case).
   */
  private _buildUrl(path: string): string {
    if (this.apiType === "azure") {
      return `${this.apiBase}/${path}?api-version=${this.apiVersion}`;
    }
    let url = this.apiBase;
    if (!url) {
      throw new Error(
        "No API base URL provided. Please set the 'apiBase' option in config.json"
      );
    }
    if (url.endsWith("/")) {
      url = url.slice(0, -1);
    }
    if (!url.endsWith("/v1")) {
      url += "/v1";
    }
    return `${url}/${path}`;
  }

  protected _getChatUrl() {
    return this._buildUrl("chat/completions");
  }

  protected _getCompletionUrl() {
    return this._buildUrl("completions");
  }
}

export default SAPGenAIHub;
2 changes: 2 additions & 0 deletions core/llm/llms/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import OpenAI from "./OpenAI";
import Replicate from "./Replicate";
import TextGenWebUI from "./TextGenWebUI";
import Together from "./Together";
import SAPGenAIHub from "./SAPGenAIHub";

function convertToLetter(num: number): string {
let result = "";
Expand Down Expand Up @@ -88,6 +89,7 @@ const LLMs = [
Gemini,
Mistral,
Bedrock,
SAPGenAIHub,
DeepInfra,
];

Expand Down