Predict stop sequence matches during streaming #541

Open · wants to merge 5 commits into main
Conversation

@Y4hL (Contributor) commented Mar 6, 2024

When streaming with mlx_lm/server.py, we should detect potential stop sequence matches and keep generating tokens until we know there is no match. This prevents the server from sending part of a stop sequence to the client before the full match is found.
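To make the failure mode concrete, here is a small sketch (illustrative token IDs, not code from the PR) of a naive streamer that forwards each token as soon as it is generated, leaking part of the stop sequence to the client:

stop_sequence = [27, 28, 29]
generated = [5, 6, 27, 28, 29]  # model output that ends with the stop sequence

sent = []
for token in generated:
    sent.append(token)  # a naive server streams each token immediately
    if sent[-len(stop_sequence):] == stop_sequence:
        break  # match detected, but 27 and 28 were already sent

print(sent)  # [5, 6, 27, 28, 29] - the client has received the stop sequence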

Fixes #524 (Using token sequences as stop criteria does not work in mlx_lm).

My implementation adds a new function, sequence_overlap, which checks how much the end of sequence 1 overlaps with the start of sequence 2. It checks larger overlaps first and returns the overlap length as an integer.

The server checks for overlaps and, when one is found, keeps generating tokens instead of sending the buffered ones to the client:

if any(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences):
    continue
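Because the loop continues without clearing the buffered tokens, they accumulate until either the overlap disappears (and they are sent) or a full stop sequence completes (and they are dropped), as the example below demonstrates.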

The sequence_overlap implementation can be tested with this example:

from typing import Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    """
    Check how much overlap two sequences have.
        Only checks the end of s1 overlapping the start of s2

    Args:
        s1 (Sequence): The first sequence, which end is checked
        s2 (Sequence): The second sequence, which beginning is checked

    Returns:
        int: The amount of overlap between s1 and s2
    """
    # Count down from the length of the smaller list -> Checks for larger overlaps first
    for index in range(min(len(s1), len(s2)), 0, -1):
        # Check if they have index amount of overlap
        if s1[-index:] == s2[:index]:
            return index
    return 0


stop_sequence = [27, 28, 29]

tokens = []
new_tokens = []

for token in range(50):
    tokens.append(token)
    new_tokens.append(token)

    # This should always be the first check, since it needs to be performed on every token
    if len(tokens) >= len(stop_sequence) and tokens[-len(stop_sequence):] == stop_sequence:
        print("Contains stop sequence:", new_tokens)
        tokens = tokens[:-len(stop_sequence)]
        new_tokens.clear()
        break

    # Keep generating until we know the buffered tokens cannot start a stop sequence
    if sequence_overlap(tokens, stop_sequence):
        print("Found a possible start to a stop sequence:", new_tokens)
        continue

    # Process new tokens
    print("Processing new tokens:", new_tokens)
    new_tokens.clear()

# If generation ends while a partial match to a stop sequence is still buffered,
# we need to process the leftover tokens, since the loop above skipped them with continue
if new_tokens:
    print("Processing leftover new tokens", new_tokens)
    new_tokens.clear()

print("Full sequence:", tokens)

@awni (Member) commented Mar 6, 2024

Oh much better, thank you 😄

[Review thread on llms/mlx_lm/server.py (outdated, resolved)]
@awni (Member) commented Mar 14, 2024

What about adding your test (modified for unittest) as a test case in a new test file test_server.py in the tests directory: https://github.com/ml-explore/mlx-examples/tree/main/llms/tests ?

[Review thread on llms/mlx_lm/server.py (outdated, resolved)]
@Y4hL (Contributor, author) commented Mar 14, 2024

Yeah, I'll look into writing a unittest.
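For illustration, a minimal sketch of what such a test could look like (the import path and test names here are assumptions, not necessarily the test that was merged):

import unittest

# Assumed import path: the PR adds sequence_overlap to llms/mlx_lm/server.py
from mlx_lm.server import sequence_overlap


class TestSequenceOverlap(unittest.TestCase):
    def test_no_overlap(self):
        self.assertEqual(sequence_overlap([1, 2, 3], [27, 28, 29]), 0)

    def test_partial_overlap(self):
        # The end of the first sequence matches the start of the second
        self.assertEqual(sequence_overlap([1, 2, 27], [27, 28, 29]), 1)
        self.assertEqual(sequence_overlap([1, 27, 28], [27, 28, 29]), 2)

    def test_full_overlap(self):
        self.assertEqual(sequence_overlap([27, 28, 29], [27, 28, 29]), 3)


if __name__ == "__main__":
    unittest.main()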

Commits:
- Check for overlap of stop sequences and the tokens array for potential sequence matches after more tokens get generated; generate tokens until we can confirm that the stop sequence is not met.
- Added a test for the sequence_overlap method.
@awni (Member) commented Mar 20, 2024

Hi! Is this ready to be reviewed again?

@Y4hL (Contributor, author) commented Mar 20, 2024

@awni yes
