Skip to content

Commit

Permalink
CosineSimilarity protection from zeros; full coverage tests. (#392)
Browse files Browse the repository at this point in the history
Added `max(sqrt(normA), sqrt(normB)), EPSILON)` to insulate from zero
vectors. Tested error results of embeddings of different sizes.
  • Loading branch information
crutcher committed Dec 30, 2023
1 parent d673223 commit e2ba220
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

public class CosineSimilarity {
public static final float EPSILON = 1e-8f;

/**
* Calculates cosine similarity between two vectors.
Expand All @@ -20,6 +21,9 @@ public class CosineSimilarity {
* 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
* <p>
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
* <p>
* Embeddings of all-zeros vectors are considered orthogonal to all other vectors;
* including other all-zeros vectors.
*
* @param embeddingA first embedding vector
* @param embeddingB second embedding vector
Expand Down Expand Up @@ -47,7 +51,8 @@ public static double between(Embedding embeddingA, Embedding embeddingB) {
normB += vectorB[i] * vectorB[i];
}

return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
// Avoid division by zero.
return dotProduct / Math.max(Math.sqrt(normA) * Math.sqrt(normB), EPSILON);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
package dev.langchain4j.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import org.assertj.core.api.WithAssertions;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;

class CosineSimilarityTest {
class CosineSimilarityTest implements WithAssertions {
@Test
public void test_bad() {
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1});
Embedding embeddingB = Embedding.from(new float[]{1, 1, 1, 1});

assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> CosineSimilarity.between(embeddingA, embeddingB))
.withMessage("Length of vector a (3) must be equal to the length of vector b (4)");
}

@Test
public void test_zeros() {
Embedding embeddingA = Embedding.from(new float[]{0, 0, 0});
Embedding embeddingB = Embedding.from(new float[]{0, 0, 0});

assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(0, withPercentage(1));
}

@Test
void should_calculate_cosine_similarity() {
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1});
Embedding embeddingB = Embedding.from(new float[]{-1, -1, -1});
Embedding embeddingA = Embedding.from(new float[]{1, -1, 1});
Embedding embeddingB = Embedding.from(new float[]{-1, 1, -1});

assertThat(CosineSimilarity.between(embeddingA, embeddingA)).isCloseTo(1, withPercentage(1));
assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(-1, withPercentage(1));
Expand Down

0 comments on commit e2ba220

Please sign in to comment.