Skip to content

Commit

Permalink
Update package versions and add support for retrieving a specific num…
Browse files Browse the repository at this point in the history
…ber of documents
  • Loading branch information
n4ze3m committed Apr 26, 2024
1 parent e95ee8a commit efc922d
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 39 deletions.
2 changes: 1 addition & 1 deletion app/ui/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "app",
"private": true,
"version": "1.8.1",
"version": "1.8.2",
"type": "module",
"scripts": {
"dev": "vite",
Expand Down
1 change: 1 addition & 0 deletions app/ui/src/@types/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export type BotSettings = {
public_id: string;
temperature: number;
embedding: string;
noOfDocumentsToRetrieve: number;
qaPrompt: string;
questionGeneratorPrompt: string;
streaming: boolean;
Expand Down
29 changes: 28 additions & 1 deletion app/ui/src/components/Bot/Settings/SettingsCard.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import { Form, Input, notification, Select, Slider, Switch } from "antd";
import {
Form,
Input,
InputNumber,
notification,
Select,
Slider,
Switch,
} from "antd";
import { useNavigate, useParams } from "react-router-dom";
import api from "../../../services/api";
import { useMutation, useQueryClient } from "@tanstack/react-query";
Expand Down Expand Up @@ -155,6 +163,7 @@ export const SettingsCard: React.FC<BotSettings> = ({
bot_protect: data.bot_protect,
use_rag: data.use_rag,
bot_model_api_key: data.bot_model_api_key,
noOfDocumentsToRetrieve: data.noOfDocumentsToRetrieve,
}}
form={form}
requiredMark={false}
Expand Down Expand Up @@ -260,6 +269,24 @@ export const SettingsCard: React.FC<BotSettings> = ({
/>
</Form.Item>

<Form.Item
name="noOfDocumentsToRetrieve"
label="Number of documents to retrieve"
rules={[
{
required: true,
message:
"Please input a number of documents to retrieve!",
},
]}
>
<InputNumber
min={0}
style={{ width: "100%" }}
placeholder="Enter number of documents to retrieve"
/>
</Form.Item>

<Form.Item
label={"Question Answering Prompt (System Prompt)"}
name="qaPrompt"
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "dialoqbase",
"version": "1.8.1",
"version": "1.8.2",
"description": "Create chatbots with ease",
"scripts": {
"ui:dev": "pnpm run --filter ui dev",
Expand Down
2 changes: 2 additions & 0 deletions server/prisma/migrations/q_26/migration.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Bot" ADD COLUMN "noOfDocumentsToRetrieve" INTEGER DEFAULT 4;
1 change: 1 addition & 0 deletions server/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ model Bot {
description String?
createdAt DateTime @default(now())
temperature Float @default(0.7)
noOfDocumentsToRetrieve Int? @default(4)
model String @default("gpt-3.5-turbo")
provider String @default("openai")
embedding String @default("openai")
Expand Down
3 changes: 3 additions & 0 deletions server/src/schema/api/v1/bot/bot/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ export const updateBotByIdSchema: FastifySchema = {
bot_model_api_key: {
type: "string",
},
noOfDocumentsToRetrieve: {
type: "number",
}
},
},
};
Expand Down
64 changes: 35 additions & 29 deletions server/src/utils/hybrid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,47 +40,53 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
protected async similaritySearch(
query: string,
k: number,
_callbacks?: Callbacks,
_callbacks?: Callbacks
): Promise<SearchResult[]> {
try {
const embeddedQuery = await this.embeddings.embedQuery(query);

const embeddedQuery = await this.embeddings.embedQuery(query);
const vector = `[${embeddedQuery.join(",")}]`;
const bot_id = this.botId;

const vector = `[${embeddedQuery.join(",")}]`;
const bot_id = this.botId;

const data = await prisma.$queryRaw`
const data = await prisma.$queryRaw`
SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int)
`;

const result: [Document, number, number][] = (
data as SearchEmbeddingsResponse[]
).map((resp) => [
new Document({
metadata: resp.metadata,
pageContent: resp.content,
}),
resp.similarity * 10,
resp.id,
]);


return result;
} catch (e) {
console.log(e)
return []
}
const result: [Document, number, number][] = (
data as SearchEmbeddingsResponse[]
).map((resp) => [
new Document({
metadata: resp.metadata,
pageContent: resp.content,
}),
resp.similarity * 10,
resp.id,
]);

return result;
} catch (e) {
console.log(e);
return [];
}
}

protected async keywordSearch(
query: string,
k: number,
k: number
): Promise<SearchResult[]> {
const query_text = query;
const bot_id = this.botId;

const botInfo = await prisma.bot.findFirst({
where: {
id: bot_id,
},
});

const match_count = botInfo?.noOfDocumentsToRetrieve || k;

const data = await prisma.$queryRaw`
SELECT * FROM "kw_match_documents"(query_text := ${query_text}::text, bot_id := ${bot_id}::text,match_count := ${k}::int)
SELECT * FROM "kw_match_documents"(query_text := ${query_text}::text, bot_id := ${bot_id}::text,match_count := ${match_count}::int)
`;

const result: [Document, number, number][] = (
Expand All @@ -104,12 +110,12 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {
query: string,
similarityK: number,
keywordK: number,
callbacks?: Callbacks,
callbacks?: Callbacks
): Promise<SearchResult[]> {
const similarity_search = this.similaritySearch(
query,
similarityK,
callbacks,
callbacks
);

const keyword_search = this.keywordSearch(query, keywordK);
Expand All @@ -136,13 +142,13 @@ export class DialoqbaseHybridRetrival extends BaseRetriever {

async _getRelevantDocuments(
query: string,
runManager?: CallbackManagerForRetrieverRun,
runManager?: CallbackManagerForRetrieverRun
): Promise<Document[]> {
const searchResults = await this.hybridSearch(
query,
this.similarityK,
this.keywordK,
runManager?.getChild("hybrid_search"),
runManager?.getChild("hybrid_search")
);

return searchResults.map(([doc]) => doc);
Expand Down
21 changes: 14 additions & 7 deletions server/src/utils/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ export class DialoqbaseVectorStore extends VectorStore {
if (row?.embedding) {
const vector = `[${row.embedding.join(",")}]`;
const content = row?.content.replace(/\x00/g, "").trim();
await prisma
.$executeRaw`INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES (${content}, ${vector}::vector, ${row.metadata}, ${row.botId}, ${row.sourceId})`;
await prisma.$executeRaw`INSERT INTO "BotDocument" ("content", "embedding", "metadata", "botId", "sourceId") VALUES (${content}, ${vector}::vector, ${row.metadata}, ${row.botId}, ${row.sourceId})`;
}
});
} catch (e) {
Expand All @@ -57,7 +56,7 @@ export class DialoqbaseVectorStore extends VectorStore {
static async fromDocuments(
docs: Document[],
embeddings: Embeddings,
dbConfig: DialoqbaseLibArgs,
dbConfig: DialoqbaseLibArgs
) {
const instance = new this(embeddings, dbConfig);
await instance.addDocuments(docs);
Expand All @@ -68,7 +67,7 @@ export class DialoqbaseVectorStore extends VectorStore {
texts: string[],
metadatas: object[] | object,
embeddings: Embeddings,
dbConfig: DialoqbaseLibArgs,
dbConfig: DialoqbaseLibArgs
) {
const docs = [];
for (let i = 0; i < texts.length; i += 1) {
Expand All @@ -84,7 +83,7 @@ export class DialoqbaseVectorStore extends VectorStore {

static async fromExistingIndex(
embeddings: Embeddings,
dbConfig: DialoqbaseLibArgs,
dbConfig: DialoqbaseLibArgs
) {
const instance = new this(embeddings, dbConfig);
return instance;
Expand All @@ -93,14 +92,22 @@ export class DialoqbaseVectorStore extends VectorStore {
async similaritySearchVectorWithScore(
query: number[],
k: number,
filter?: this["FilterType"] | undefined,
filter?: this["FilterType"] | undefined
): Promise<[Document<Record<string, any>>, number][]> {
console.log(this.botId);
const vector = `[${query.join(",")}]`;
const bot_id = this.botId;

const botInfo = await prisma.bot.findFirst({
where: {
id: bot_id,
},
});

const match_count = botInfo?.noOfDocumentsToRetrieve || k;

const data = await prisma.$queryRaw`
SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${k}::int)
SELECT * FROM "similarity_search_v2"(query_embedding := ${vector}::vector, botId := ${bot_id}::text,match_count := ${match_count}::int)
`;

const result: [Document, number][] = (
Expand Down

0 comments on commit efc922d

Please sign in to comment.