
[feat] Add binary & scalar embedding quantization support to Transformers.js #681

Closed
jonathanpv opened this issue Apr 4, 2024 · 7 comments · Fixed by #691 · May be fixed by #683
Labels
enhancement New feature or request

Comments

@jonathanpv (Contributor) commented Apr 4, 2024

Feature request

Add binary & scalar quantization support

We should extract the embedding-quantization algorithm from the sentence-transformers PR below and add it to transformers.js, so the feature-extraction pipeline can support binary vector search.

We could either add a quantize-output (or binary-output) option to the pipeline, or expose a helper method that quantizes a tensor, so the solution can be reused in other parts of the codebase.

UKPLab/sentence-transformers#2549
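For the scalar half of the request, here is a minimal sketch of int8 quantization in JavaScript. The function name and the per-dimension calibration ranges are my own assumptions, loosely mirroring how sentence-transformers' `int8` precision works:

```js
// Hypothetical sketch: scalar (int8) quantization, assuming per-dimension
// [min, max] ranges estimated from a calibration set (as sentence-transformers does).
function scalarQuantize(embeddings, minVals, maxVals) {
    return embeddings.map((vec) => {
        const out = new Int8Array(vec.length);
        for (let d = 0; d < vec.length; ++d) {
            const span = maxVals[d] - minVals[d] || 1; // guard against a zero range
            // Map [min, max] linearly onto the 256 int8 buckets [-128, 127]
            const q = Math.round(((vec[d] - minVals[d]) / span) * 255) - 128;
            out[d] = Math.max(-128, Math.min(127, q));
        }
        return out;
    });
}

const quantized = scalarQuantize(
    [[0.0, 0.5, 1.0]],
    [0, 0, 0],
    [1, 1, 1],
);
console.log(quantized[0]); // Int8Array [-128, 0, 127]
```

This keeps 4x less memory than float32 while preserving much more of the ranking than binary does, at the cost of per-dimension calibration statistics.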

Motivation

Given the performance gains of binary vector embeddings, additional quantization helper methods would be useful for client-side vector search, reducing the memory footprint.

Your contribution

I plan on making a PR, but note that I am still new to the open-source world and am learning the transformers.js repo as fast as I can.

@xenova may be able to parse the PR above and add the methods, along with test cases in the style of the repo, more quickly than me.

However, I am working on a proof of concept that uses vector embeddings in a side project, so I can contribute that work (repo or README) when I finish it.

@jonathanpv jonathanpv added the enhancement New feature or request label Apr 4, 2024
@jonathanpv (Contributor Author)

Seems like a numpy equivalent in JS doesn't exist.

I wonder how transformers.js does its math calculations.

@jonathanpv (Contributor Author) commented Apr 4, 2024

Have we considered adding numjs to transformers.js so the translation could be more 1:1, or do we prefer pulling in math functions on an as-needed basis?

@xenova (Owner) commented Apr 4, 2024

Hi there 👋 I have already worked on this a bit, and it might be a useful addition to the feature-extraction pipeline.

Here's some example code which shows how you can achieve this in javascript:

```js
import { pipeline, Tensor } from "@xenova/transformers";

function hamming_distance(arr1, arr2) {
    if (arr1.length !== arr2.length) {
        throw new Error("Typed arrays must have the same length");
    }

    let distance = 0;

    // Iterate over each byte in the typed arrays
    for (let i = 0; i < arr1.length; ++i) {
        // XOR the bytes to find differing bits
        let xorResult = arr1[i] ^ arr2[i];

        // Count set bits in the XOR result using Brian Kernighan's Algorithm
        while (xorResult) {
            ++distance;
            xorResult &= xorResult - 1;
        }
    }

    return distance;
}

function quantize_embeddings(tensor, precision) {
    if (tensor.dims.length !== 2) {
        throw new Error("The tensor must have 2 dimensions");
    }
    if (tensor.dims.at(-1) % 8 !== 0) {
        throw new Error("The last dimension of the tensor must be a multiple of 8");
    }
    if (!['binary', 'ubinary'].includes(precision)) {
        throw new Error("The precision must be either 'binary' or 'ubinary'");
    }
    // Create a typed array to store the packed bits
    const inputData = tensor.data;

    const signed = precision === 'binary';
    const cls = signed ? Int8Array : Uint8Array;
    const dtype = signed ? 'int8' : 'uint8';
    const outputData = new cls(inputData.length / 8);

    // Iterate over each number in the array
    for (let i = 0; i < inputData.length; ++i) {
        // Determine if the number is greater than 0
        const bit = inputData[i] > 0 ? 1 : 0;

        // Calculate the index in the typed array and the position within the byte
        const arrayIndex = Math.floor(i / 8);
        const bitPosition = i % 8;

        // Pack the bit into the typed array
        outputData[arrayIndex] |= bit << (7 - bitPosition);
        if (signed && bitPosition === 0) {
            outputData[arrayIndex] -= 128;
        }
    }

    return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]);
}

const embedder = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
    quantized: false,
});

const texts = ['hello', 'hi', 'banana'];
const output = await embedder(texts, {
    normalize: true,
    pooling: 'mean',
});
const embeddings = quantize_embeddings(output, 'ubinary').tolist();

const pairs = [[0, 1], [0, 2], [1, 2]];
for (const [i, j] of pairs) {
    console.log(`${texts[i]} <-> ${texts[j]}`, '|', hamming_distance(embeddings[i], embeddings[j]));
}
```

outputs:

```
hello <-> hi | 86
hello <-> banana | 165
hi <-> banana | 163
```

indicating higher similarity between hello and hi than between hello and banana, for example.
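If a bounded similarity score is preferable to a raw distance, the Hamming distance above can be normalized by the total bit count. The helper name here is my own:

```js
// Hypothetical helper: convert a Hamming distance over packed bytes into a
// similarity score in [0, 1], where byteLength is the packed embedding length.
function hammingSimilarity(distance, byteLength) {
    const totalBits = byteLength * 8;
    return 1 - distance / totalBits;
}

// For the 384-dim MiniLM embeddings above (48 packed bytes):
console.log(hammingSimilarity(86, 48).toFixed(3));  // "0.776" (hello <-> hi)
console.log(hammingSimilarity(165, 48).toFixed(3)); // "0.570" (hello <-> banana)
```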

@jonathanpv (Contributor Author)

@xenova @ashvardanian

So when I was testing the code, I went ahead and made a front end to play around with it.

I noticed the embeddings ended up being int8 and was wondering if that is as designed in the quantize_embeddings function, despite choosing ubinary or binary.

I'm curious how the algorithm works under the hood. I tried learning it from reading the code but couldn't understand it.

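To unpack the algorithm in question: each dimension is thresholded at zero, and eight of those bits are packed MSB-first into one byte. A hand trace on a hypothetical 8-dimensional vector:

```js
// Walk through quantize_embeddings' inner loop for one 8-dim embedding
// (the values here are made up for illustration).
const values = [0.3, -0.1, 0.7, -0.2, 0.0, 0.5, -0.4, 0.9];
let byte = 0;
for (let i = 0; i < 8; ++i) {
    const bit = values[i] > 0 ? 1 : 0; // sign test: 1, 0, 1, 0, 0, 1, 0, 1
    byte |= bit << (7 - i);            // MSB-first: index 0 lands at bit position 7
}
console.log(byte.toString(2).padStart(8, '0')); // "10100101"
console.log(byte); // 165 (ubinary); binary subtracts 128 per byte, giving 37
```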

@jonathanpv (Contributor Author)

I can separate out the commits so it's easier to parse; I just realized how large the diff was, lol, 5k lines.

@ashvardanian

@jonathanpv, I find it reasonable that the outputs are 8-bit integers, as long as the cosine distance between them makes sense. Just make sure to use a proper library to compute those, as NumPy and SciPy don't support mixed precision and will overflow. SimSIMD should work fine and is available for JS as well 😉
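The overflow concern can be seen in a direct JavaScript port. A sketch (not SimSIMD's implementation): JS Numbers are float64, so the accumulator below cannot overflow, but the same loop with a fixed-width int8/int16 accumulator would.

```js
// Sketch: cosine similarity between int8-quantized embeddings. In languages
// with fixed-width integers, the dot product of two int8 vectors must be
// accumulated in a wider type, or it overflows for long embeddings.
function cosineInt8(a, b) {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; ++i) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const a = new Int8Array([127, -128, 64]);
console.log(cosineInt8(a, a)); // ≈ 1 for identical vectors
```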

@jonathanpv (Contributor Author)

> @jonathanpv, I find it reasonable that the outputs are 8-bit integers, as long as the cosine distance between them makes sense. Just make sure to use a proper library to compute those, as NumPy and SciPy don't support mixed precision and will overflow. SimSIMD should work fine and is available for JS as well 😉

Oh wow, I didn't realize simsimd also had these functions; I can see why you implemented them. I'll try this out after I finish this other project. It will be nice to benchmark in the browser context.
