Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(langchain4j-milvus): MilvusEmbeddingStore supports configuring required index parameters #931

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.store.embedding.milvus;

import dev.langchain4j.store.embedding.milvus.parameter.IndexParam;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.FlushResponse;
Expand Down Expand Up @@ -81,12 +82,14 @@ static void dropCollection(MilvusServiceClient milvusClient, String collectionNa
static void createIndex(MilvusServiceClient milvusClient,
String collectionName,
IndexType indexType,
IndexParam indexParam,
MetricType metricType) {

CreateIndexParam request = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(VECTOR_FIELD_NAME)
.withIndexType(indexType)
.withExtraParam(indexParam.toExtraParam())
.withMetricType(metricType)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.internal.ValidationUtils.ensureTrue;
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.*;
import static dev.langchain4j.store.embedding.milvus.CollectionRequestBuilder.buildSearchRequest;
import static dev.langchain4j.store.embedding.milvus.Generator.generateRandomIds;
Expand All @@ -17,10 +18,11 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.milvus.parameter.IndexParam;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.param.ConnectParam;
Expand Down Expand Up @@ -59,6 +61,7 @@ public MilvusEmbeddingStore(
String collectionName,
Integer dimension,
IndexType indexType,
IndexParam indexParam,
MetricType metricType,
String uri,
String token,
Expand Down Expand Up @@ -87,8 +90,17 @@ public MilvusEmbeddingStore(
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);

if (!hasCollection(milvusClient, this.collectionName)) {
indexType = getOrDefault(indexType, FLAT);
if (indexParam == null) {
if (IndexParam.isIndexParamNullable(indexType)) {
indexParam = IndexParam.EMPTY_INSTANCE;
}
}
ensureNotNull(indexParam, "IndexParam is required for indexType " + indexType);
ensureTrue(indexParam.support(indexType), String.format("IndexParam %s does not support IndexType %s", indexParam.getClass(), indexType));
// validate IndexParam before creating the collection to prevent exceptions caused by invalid indices
createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
createIndex(milvusClient, this.collectionName, indexType, indexParam, this.metricType);
}

loadCollectionInMemory(milvusClient, collectionName);
Expand Down Expand Up @@ -185,6 +197,7 @@ public static class Builder {
private String collectionName;
private Integer dimension;
private IndexType indexType;
private IndexParam indexParam;
private MetricType metricType;
private String uri;
private String token;
Expand Down Expand Up @@ -245,6 +258,17 @@ public Builder indexType(IndexType indexType) {
return this;
}

/**
 * Sets the index parameters that {@link #build()} passes to the store's constructor.
 * <p>
 * This parameter is required except for indexType {@link IndexType#FLAT} and {@link IndexType#BIN_FLAT}.
 * <p>
 * In the constructor code visible alongside this builder, the value is only consulted when the
 * collection does not exist yet: a {@code null} value is replaced by {@code IndexParam.EMPTY_INSTANCE}
 * if {@code IndexParam.isIndexParamNullable(indexType)} holds and is rejected otherwise, and a
 * non-null value must satisfy {@code indexParam.support(indexType)}.
 *
 * @param indexParam The parameters of the index; may be {@code null} (this setter performs no validation).
 * @return builder
 */
public Builder indexParam(IndexParam indexParam) {
this.indexParam = indexParam;
return this;
}

/**
* @param metricType The type of the metric used for similarity search.
* Default value: COSINE.
Expand Down Expand Up @@ -332,6 +356,7 @@ public MilvusEmbeddingStore build() {
collectionName,
dimension,
indexType,
indexParam,
metricType,
uri,
token,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
 * Index parameters for {@link IndexType#BIN_FLAT}.
 * <p>
 * The class declares no fields of its own; it only passes {@link IndexType#BIN_FLAT}
 * to the {@link IndexParam} constructor.
 *
 * @see <a href="https://milvus.io/docs/index.md#BIN_FLAT">Index#BIN_FLAT</a>
 */
public class BinFlatIndexParam extends IndexParam {

    /**
     * Creates the parameter object for a {@code BIN_FLAT} index.
     */
    public BinFlatIndexParam() {
        super(IndexType.BIN_FLAT);
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link BinFlatIndexParam}. There is nothing to configure;
     * it exists so that all index parameter types can be created the same way.
     */
    public static final class Builder {

        public Builder() {
        }

        /**
         * @return a new {@link BinFlatIndexParam}
         */
        public BinFlatIndexParam build() {
            return new BinFlatIndexParam();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;

/**
 * Index parameters for {@link IndexType#BIN_IVF_FLAT}.
 *
 * @see <a href="https://milvus.io/docs/index.md#BIN_IVF_FLAT">Index#BIN_IVF_FLAT</a>
 */
public class BinIvfFlatIndexParam extends IndexParam {

    /**
     * Number of cluster units, Range: [1, 65536]
     */
    private final Integer nlist;

    /**
     * Creates the parameter object for a {@code BIN_IVF_FLAT} index.
     *
     * @param nlist number of cluster units; must be in the range [1, 65536]
     * @throws IllegalArgumentException if {@code nlist} is {@code null} or out of range
     */
    public BinIvfFlatIndexParam(Integer nlist) {
        super(IndexType.BIN_IVF_FLAT);
        // ensureBetween expects the parameter NAME as its last argument and composes the
        // "<name> must be between <min> and <max>, but is: <value>" message itself.
        // Passing a full sentence here produced a duplicated, garbled error message.
        ensureBetween(nlist, 1, 65536, "nlist");
        this.nlist = nlist;
    }

    /**
     * @return a new {@link Builder} with no value set
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the configured number of cluster units
     */
    public Integer getNlist() {
        return nlist;
    }

    /**
     * Builder for {@link BinIvfFlatIndexParam}. {@code nlist} has no default and is
     * validated when {@link #build()} invokes the constructor.
     */
    public static final class Builder {
        private Integer nlist;

        public Builder() {
        }

        /**
         * @param nlist number of cluster units, Range: [1, 65536]
         * @return builder
         */
        public Builder nlist(Integer nlist) {
            this.nlist = nlist;
            return this;
        }

        /**
         * @return a new {@link BinIvfFlatIndexParam}
         * @throws IllegalArgumentException if {@code nlist} was not set or is out of range
         */
        public BinIvfFlatIndexParam build() {
            return new BinIvfFlatIndexParam(nlist);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
 * Index parameters for {@link IndexType#DISKANN}.
 * <p>
 * The class declares no fields of its own; it only passes {@link IndexType#DISKANN}
 * to the {@link IndexParam} constructor.
 *
 * @see <a href="https://milvus.io/docs/disk_index.md#Index-and-search-settings">Disk Index#DISKANN</a>
 */
public class DiskannIndexParam extends IndexParam {

    /**
     * Creates the parameter object for a {@code DISKANN} index.
     */
    public DiskannIndexParam() {
        super(IndexType.DISKANN);
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link DiskannIndexParam}. There is nothing to configure;
     * it exists so that all index parameter types can be created the same way.
     */
    public static final class Builder {

        public Builder() {
        }

        /**
         * @return a new {@link DiskannIndexParam}
         */
        public DiskannIndexParam build() {
            return new DiskannIndexParam();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

/**
 * Index parameters for {@link IndexType#FLAT}.
 * <p>
 * Empty placeholder class: it declares no fields of its own and only passes
 * {@link IndexType#FLAT} to the {@link IndexParam} constructor.
 *
 * @see <a href="https://milvus.io/docs/index.md#FLAT">Index#FLAT</a>
 */
public class FlatIndexParam extends IndexParam {

    /**
     * Creates the parameter object for a {@code FLAT} index.
     */
    public FlatIndexParam() {
        super(IndexType.FLAT);
    }

    /**
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link FlatIndexParam}. There is nothing to configure;
     * it exists so that all index parameter types can be created the same way.
     */
    public static final class Builder {

        public Builder() {
        }

        /**
         * @return a new {@link FlatIndexParam}
         */
        public FlatIndexParam build() {
            return new FlatIndexParam();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;

/**
 * Index parameters for {@link IndexType#GPU_IVF_FLAT}.
 * <p>
 * for more information, see <a href="https://milvus.io/docs/index-with-gpu.md#Prepare-index-parameters">GPU Index</a>
 * parameter same as <a href="https://milvus.io/docs/index.md#IVF_FLAT">Index#IVF_FLAT</a>
 */
public class GpuIvfFlatIndexParam extends IndexParam {

    /**
     * Number of cluster units, Range: [1, 65536]
     */
    private final Integer nlist;

    /**
     * Creates the parameter object for a {@code GPU_IVF_FLAT} index.
     *
     * @param nlist number of cluster units; must be in the range [1, 65536]
     * @throws IllegalArgumentException if {@code nlist} is {@code null} or out of range
     */
    public GpuIvfFlatIndexParam(Integer nlist) {
        super(IndexType.GPU_IVF_FLAT);
        // ensureBetween expects the parameter NAME as its last argument and composes the
        // "<name> must be between <min> and <max>, but is: <value>" message itself.
        // Passing a full sentence here produced a duplicated, garbled error message.
        ensureBetween(nlist, 1, 65536, "nlist");
        this.nlist = nlist;
    }

    /**
     * Static factory for the builder, matching the other index parameter classes
     * (this class previously offered only {@code new Builder()}).
     *
     * @return a new {@link Builder} with no value set
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the configured number of cluster units
     */
    public Integer getNlist() {
        return nlist;
    }

    /**
     * Builder for {@link GpuIvfFlatIndexParam}. {@code nlist} has no default and is
     * validated when {@link #build()} invokes the constructor.
     */
    public static final class Builder {
        private Integer nlist;

        public Builder() {
        }

        /**
         * @param nlist number of cluster units, Range: [1, 65536]
         * @return builder
         */
        public Builder nlist(Integer nlist) {
            this.nlist = nlist;
            return this;
        }

        /**
         * @return a new {@link GpuIvfFlatIndexParam}
         * @throws IllegalArgumentException if {@code nlist} was not set or is out of range
         */
        public GpuIvfFlatIndexParam build() {
            return new GpuIvfFlatIndexParam(nlist);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package dev.langchain4j.store.embedding.milvus.parameter;

import io.milvus.param.IndexType;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * Index parameters for {@link IndexType#GPU_IVF_PQ}.
 * <p>
 * for more information, see <a href="https://milvus.io/docs/index-with-gpu.md#Prepare-index-parameters">GPU Index</a>
 * parameter same as <a href="https://milvus.io/docs/index.md#IVF_PQ">Index#IVF_PQ</a>
 */
public class GpuIvfPqIndexParam extends IndexParam {

    /**
     * Value used for {@code nbits} when the caller does not choose one.
     */
    private static final int DEFAULT_NBITS = 8;

    /**
     * Number of cluster units, Range: [1, 65536]
     */
    private final Integer nlist;
    /**
     * Number of factors of product quantization, Range: dim mod m == 0
     */
    private final Integer m;
    /**
     * [Optional] Number of bits in which each low-dimensional vector is stored. Range: [1, 16], Default: 8
     */
    private final Integer nbits;

    /**
     * Creates the parameter object with {@code nbits} set to its default of 8.
     *
     * @param nlist number of cluster units; must be in the range [1, 65536]
     * @param m     number of factors of product quantization; must not be {@code null}
     *              (the vector dimension must be divisible by it: dim mod m == 0)
     * @throws IllegalArgumentException if {@code nlist} or {@code m} is invalid
     */
    public GpuIvfPqIndexParam(Integer nlist, Integer m) {
        this(nlist, m, DEFAULT_NBITS);
    }

    /**
     * Creates the parameter object for a {@code GPU_IVF_PQ} index.
     *
     * @param nlist number of cluster units; must be in the range [1, 65536]
     * @param m     number of factors of product quantization; must not be {@code null}
     *              (the vector dimension must be divisible by it: dim mod m == 0)
     * @param nbits number of bits per low-dimensional vector, range [1, 16];
     *              {@code null} is accepted and stored as-is
     * @throws IllegalArgumentException if any argument is invalid
     */
    public GpuIvfPqIndexParam(Integer nlist, Integer m, Integer nbits) {
        super(IndexType.GPU_IVF_PQ);
        // The ValidationUtils helpers expect the parameter NAME as their last argument and
        // compose the error message themselves; passing whole sentences (some with typos)
        // produced duplicated, garbled messages.
        ensureBetween(nlist, 1, 65536, "nlist");
        ensureNotNull(m, "m");
        if (nbits != null) {
            ensureBetween(nbits, 1, 16, "nbits");
        }
        this.nlist = nlist;
        this.m = m;
        this.nbits = nbits;
    }

    /**
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * @return the configured number of cluster units
     */
    public Integer getNlist() {
        return nlist;
    }

    /**
     * @return the configured number of product quantization factors
     */
    public Integer getM() {
        return m;
    }

    /**
     * @return the configured number of bits per low-dimensional vector, possibly {@code null}
     */
    public Integer getNbits() {
        return nbits;
    }

    /**
     * Builder for {@link GpuIvfPqIndexParam}.
     * <p>
     * {@code nlist} and {@code m} are required and have no default; {@code nbits} defaults to 8.
     * (The default of 8 was previously attached to {@code nlist} by mistake, which silently
     * created indexes with only 8 clusters when {@code nlist} was forgotten.)
     */
    public static final class Builder {
        private Integer nlist;
        private Integer m;
        private Integer nbits = DEFAULT_NBITS;

        public Builder() {
        }

        /**
         * @param nbits number of bits per low-dimensional vector, Range: [1, 16], Default: 8
         * @return builder
         */
        public Builder nbits(Integer nbits) {
            this.nbits = nbits;
            return this;
        }

        /**
         * @param m number of factors of product quantization, Range: dim mod m == 0
         * @return builder
         */
        public Builder m(Integer m) {
            this.m = m;
            return this;
        }

        /**
         * @param nlist number of cluster units, Range: [1, 65536]
         * @return builder
         */
        public Builder nlist(Integer nlist) {
            this.nlist = nlist;
            return this;
        }

        /**
         * @return a new {@link GpuIvfPqIndexParam}
         * @throws IllegalArgumentException if a required value is missing or any value is out of range
         */
        public GpuIvfPqIndexParam build() {
            return new GpuIvfPqIndexParam(nlist, m, nbits);
        }
    }
}